Added customer importer
This commit is contained in:
parent
be50dca16d
commit
1bd8336f70
145
package/moq/importer.go
Normal file
145
package/moq/importer.go
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
package moq
|
||||||
|
|
||||||
|
// taken from https://github.com/ernesto-jimenez/gogen
|
||||||
|
// Copyright (c) 2015 Ernesto Jiménez
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"go/ast"
|
||||||
|
"go/importer"
|
||||||
|
"go/parser"
|
||||||
|
"go/token"
|
||||||
|
"go/types"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type customImporter struct {
|
||||||
|
imported map[string]*types.Package
|
||||||
|
base types.Importer
|
||||||
|
skipTestFiles bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *customImporter) Import(path string) (*types.Package, error) {
|
||||||
|
var err error
|
||||||
|
if path == "" || path[0] == '.' {
|
||||||
|
path, err = filepath.Abs(filepath.Clean(path))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
path = stripGopath(path)
|
||||||
|
}
|
||||||
|
if pkg, ok := i.imported[path]; ok {
|
||||||
|
return pkg, nil
|
||||||
|
}
|
||||||
|
pkg, err := i.fsPkg(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
i.imported[path] = pkg
|
||||||
|
return pkg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func gopathDir(pkg string) (string, error) {
|
||||||
|
for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") {
|
||||||
|
absPath, err := filepath.Abs(path.Join(gopath, "src", pkg))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if dir, err := os.Stat(absPath); err == nil && dir.IsDir() {
|
||||||
|
return absPath, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("%s not in $GOPATH", pkg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeGopath(p string) string {
|
||||||
|
for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") {
|
||||||
|
p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1)
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *customImporter) fsPkg(pkg string) (*types.Package, error) {
|
||||||
|
dir, err := gopathDir(pkg)
|
||||||
|
if err != nil {
|
||||||
|
return importOrErr(i.base, pkg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dirFiles, err := ioutil.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return importOrErr(i.base, pkg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fset := token.NewFileSet()
|
||||||
|
var files []*ast.File
|
||||||
|
for _, fileInfo := range dirFiles {
|
||||||
|
if fileInfo.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n := fileInfo.Name()
|
||||||
|
if path.Ext(fileInfo.Name()) != ".go" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if i.skipTestFiles && strings.Contains(fileInfo.Name(), "_test.go") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
file := path.Join(dir, n)
|
||||||
|
src, err := ioutil.ReadFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f, err := parser.ParseFile(fset, file, src, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
files = append(files, f)
|
||||||
|
}
|
||||||
|
conf := types.Config{
|
||||||
|
Importer: i,
|
||||||
|
}
|
||||||
|
p, err := conf.Check(pkg, fset, files, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return importOrErr(i.base, pkg, err)
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func importOrErr(base types.Importer, pkg string, err error) (*types.Package, error) {
|
||||||
|
p, impErr := base.Import(pkg)
|
||||||
|
if impErr != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newImporter returns an importer that will try to import code from gopath before using go/importer.Default and skipping test files
|
||||||
|
func newImporter() types.Importer {
|
||||||
|
return &customImporter{
|
||||||
|
imported: make(map[string]*types.Package),
|
||||||
|
base: importer.Default(),
|
||||||
|
skipTestFiles: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// // DefaultWithTestFiles same as Default but it parses test files too
|
||||||
|
// func DefaultWithTestFiles() types.Importer {
|
||||||
|
// return &customImporter{
|
||||||
|
// imported: make(map[string]*types.Package),
|
||||||
|
// base: importer.Default(),
|
||||||
|
// skipTestFiles: false,
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// stripGopath teks the directory to a package and remove the gopath to get the
|
||||||
|
// cannonical package name
|
||||||
|
func stripGopath(p string) string {
|
||||||
|
for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") {
|
||||||
|
p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1)
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
@ -4,7 +4,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"go/ast"
|
"go/ast"
|
||||||
"go/importer"
|
|
||||||
"go/parser"
|
"go/parser"
|
||||||
"go/token"
|
"go/token"
|
||||||
"go/types"
|
"go/types"
|
||||||
@ -77,7 +76,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
|
|||||||
files[i] = file
|
files[i] = file
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
conf := types.Config{Importer: importer.Default()}
|
conf := types.Config{Importer: newImporter()}
|
||||||
tpkg, err := conf.Check(m.src, m.fset, files, nil)
|
tpkg, err := conf.Check(m.src, m.fset, files, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
Loading…
Reference in New Issue
Block a user