diff --git a/pkg/moq/importer.go b/pkg/moq/importer.go index 66c6109..acf1225 100644 --- a/pkg/moq/importer.go +++ b/pkg/moq/importer.go @@ -18,6 +18,7 @@ import ( ) type customImporter struct { + source string imported map[string]*types.Package base types.Importer skipTestFiles bool @@ -43,8 +44,13 @@ func (i *customImporter) Import(path string) (*types.Package, error) { return pkg, nil } -func gopathDir(pkg string) (string, error) { - for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") { +func gopathDir(source, pkg string) (string, error) { + // check vendor directory + vendorPath, found := vendorPath(source, pkg) + if found { + return vendorPath, nil + } + for _, gopath := range gopaths() { absPath, err := filepath.Abs(path.Join(gopath, "src", pkg)) if err != nil { return "", err @@ -53,18 +59,50 @@ func gopathDir(pkg string) (string, error) { return absPath, nil } } - return "", fmt.Errorf("%s not in $GOPATH", pkg) + return "", fmt.Errorf("%s not in $GOPATH or %s", pkg, path.Join(source, "vendor")) +} + +func vendorPath(source, pkg string) (string, bool) { + for { + if isGopath(source) { + return "", false + } + vendorPath, err := filepath.Abs(path.Join(source, "vendor", pkg)) + if err != nil { + return "", false + } + if dir, err := os.Stat(vendorPath); err == nil && dir.IsDir() { + return vendorPath, true + } + source = filepath.Dir(source) + if source == "." { + return "", false + } + } } func removeGopath(p string) string { - for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") { + for _, gopath := range gopaths() { p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1) } return p } +func gopaths() []string { + return strings.Split(os.Getenv("GOPATH"), string(filepath.ListSeparator)) +} + +func isGopath(path string) bool { + for _, p := range gopaths() { + if p == path { + return true + } + } + return false +} + func (i *customImporter) fsPkg(pkg string) (*types.Package, error) { - dir, err := gopathDir(pkg) + dir, err := gopathDir(i.source, pkg) if err != nil { return importOrErr(i.base, pkg, err) } @@ -118,8 +156,9 @@ func importOrErr(base types.Importer, pkg string, err error) (*types.Package, er } // newImporter returns an importer that will try to import code from gopath before using go/importer.Default and skipping test files -func newImporter() types.Importer { +func newImporter(source string) types.Importer { return &customImporter{ + source: source, imported: make(map[string]*types.Package), base: importer.Default(), skipTestFiles: true, @@ -138,8 +177,8 @@ func newImporter() types.Importer { // stripGopath teks the directory to a package and remove the gopath to get the // cannonical package name func stripGopath(p string) string { - for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") { - p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1) + for _, gopath := range gopaths() { + p = strings.TrimPrefix(p, path.Join(gopath, "src")+"/") } return p } diff --git a/pkg/moq/moq.go b/pkg/moq/moq.go index c2339ce..a006121 100644 --- a/pkg/moq/moq.go +++ b/pkg/moq/moq.go @@ -76,7 +76,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { files[i] = file i++ } - conf := types.Config{Importer: newImporter()} + conf := types.Config{Importer: newImporter(m.src)} tpkg, err := conf.Check(m.src, m.fset, files, nil) if err != nil { return err diff --git a/pkg/moq/moq_test.go b/pkg/moq/moq_test.go index bfdab08..d54d70b 100644 --- a/pkg/moq/moq_test.go +++ b/pkg/moq/moq_test.go @@ -2,7 +2,6 @@ package moq import ( "bytes" - "log" "strings" "testing" ) @@ -183,5 +182,14 @@ func TestVendoredPackages(t *testing.T) { if err != nil { t.Errorf("mock error: %s", err) } - log.Println(buf.String()) + s := buf.String() + // assertions of things that should be mentioned + var strs = []string{ + `"github.com/matryer/somerepo"`, + } + for _, str := range strs { + if !strings.Contains(s, str) { + t.Errorf("expected but missing: \"%s\"", str) + } + } } diff --git a/pkg/moq/testpackages/vendoring/user/vendor/github.com/matryer/somerepo/code.go b/pkg/moq/testpackages/vendoring/vendor/github.com/matryer/somerepo/code.go similarity index 100% rename from pkg/moq/testpackages/vendoring/user/vendor/github.com/matryer/somerepo/code.go rename to pkg/moq/testpackages/vendoring/vendor/github.com/matryer/somerepo/code.go