diff --git a/pkg/moq/moq.go b/pkg/moq/moq.go index 3479979..a29d6ab 100644 --- a/pkg/moq/moq.go +++ b/pkg/moq/moq.go @@ -4,10 +4,7 @@ import ( "bytes" "errors" "fmt" - "go/ast" "go/format" - "go/parser" - "go/token" "go/types" "io" "os" @@ -64,10 +61,8 @@ var golintInitialisms = []string{ // Mocker can generate mock structs. type Mocker struct { - src string + srcPkg *packages.Package tmpl *template.Template - fset *token.FileSet - pkgs map[string]*ast.Package pkgName string pkgPath string @@ -76,31 +71,24 @@ type Mocker struct { // New makes a new Mocker for the specified package directory. func New(src, packageName string) (*Mocker, error) { - fset := token.NewFileSet() - noTestFiles := func(i os.FileInfo) bool { - return !strings.HasSuffix(i.Name(), "_test.go") - } - wd, err := os.Getwd() + srcPkg, err := pkgInfoFromPath(src, packages.LoadSyntax) if err != nil { - return nil, fmt.Errorf("failed to determin current working directory: %s", err) + return nil, fmt.Errorf("Couldn't load source package: %s", err) } - packagePath := stripGopath(filepath.Join(wd, src, packageName)) + pkgPath := srcPkg.PkgPath - pkgs, err := parser.ParseDir(fset, src, noTestFiles, parser.SpuriousErrors) - if err != nil { - return nil, err - } if len(packageName) == 0 { - for pkgName := range pkgs { - if strings.Contains(pkgName, "_test") { - continue - } - packageName = pkgName - break + packageName = srcPkg.Name + } else { + mockPkgPath := filepath.Join(src, packageName) + if _, err := os.Stat(mockPkgPath); os.IsNotExist(err) { + os.Mkdir(mockPkgPath, os.ModePerm) } - } - if len(packageName) == 0 { - return nil, errors.New("failed to determine package name") + mockPkg, err := pkgInfoFromPath(mockPkgPath, packages.LoadFiles) + if err != nil { + return nil, fmt.Errorf("Couldn't load mock package: %s", err) + } + pkgPath = mockPkg.PkgPath } tmpl, err := template.New("moq").Funcs(templateFuncs).Parse(moqTemplate) @@ -108,12 +96,10 @@ func New(src, packageName string) (*Mocker, error) { return nil, err } return &Mocker{ - src: src, tmpl: tmpl, - fset: fset, - pkgs: pkgs, + srcPkg: srcPkg, pkgName: packageName, - pkgPath: packagePath, + pkgPath: pkgPath, imports: make(map[string]bool), }, nil } @@ -124,11 +110,6 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { return errors.New("must specify one interface") } - pkgInfo, err := pkgInfoFromPath(m.src) - if err != nil { - return err - } - doc := doc{ PackageName: m.pkgName, Imports: moqImports, @@ -136,7 +117,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { mocksMethods := false - tpkg := pkgInfo.Types + tpkg := m.srcPkg.Types for _, n := range name { iface := tpkg.Scope().Lookup(n) if iface == nil { @@ -177,7 +158,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { } var buf bytes.Buffer - err = m.tmpl.Execute(&buf, doc) + err := m.tmpl.Execute(&buf, doc) if err != nil { return err } @@ -228,23 +209,15 @@ func (m *Mocker) extractArgs(sig *types.Signature, list *types.Tuple, nameFormat return params } -func pkgInfoFromPath(src string) (*packages.Package, error) { - abs, err := filepath.Abs(src) - if err != nil { - return nil, err - } - pkgFull := stripGopath(abs) - +func pkgInfoFromPath(src string, mode packages.LoadMode) (*packages.Package, error) { conf := packages.Config{ - Mode: packages.LoadSyntax, + Mode: mode, Dir: src, } - - pkgs, err := packages.Load(&conf, pkgFull) + pkgs, err := packages.Load(&conf) if err != nil { return nil, err } - if len(pkgs) == 0 { return nil, errors.New("No packages found") } diff --git a/pkg/moq/moq_modules_test.go b/pkg/moq/moq_modules_test.go new file mode 100644 index 0000000..1b6af9d --- /dev/null +++ b/pkg/moq/moq_modules_test.go @@ -0,0 +1,129 @@ +// +build go1.11 + +package moq + +import ( + "bytes" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" +) + +// copy copies srcPath to destPath, dirs and files +func copy(srcPath, destPath string, item os.FileInfo) error { + if item.IsDir() { + if err := os.MkdirAll(destPath, os.FileMode(0755)); err != nil { + return err + } + items, err := ioutil.ReadDir(srcPath) + if err != nil { + return err + } + for _, item := range items { + src := filepath.Join(srcPath, item.Name()) + dest := filepath.Join(destPath, item.Name()) + if err := copy(src, dest, item); err != nil { + return err + } + } + } else { + src, err := os.Open(srcPath) + if err != nil { + return err + } + defer src.Close() + + dest, err := os.Create(destPath) + if err != nil { + return err + } + defer dest.Close() + + _, err = io.Copy(dest, src) + if err != nil { + return err + } + } + return nil +} + +// copyTestPackage copies test package to a temporary directory +func copyTestPackage(srcPath string) (string, error) { + tmpDir, err := ioutil.TempDir("", "moq-tests") + if err != nil { + return "", err + } + + info, err := os.Lstat(srcPath) + if err != nil { + return tmpDir, err + } + return tmpDir, copy(srcPath, tmpDir, info) +} + +func TestModulesSamePackage(t *testing.T) { + tmpDir, err := copyTestPackage("testpackages/modules") + defer os.RemoveAll(tmpDir) + if err != nil { + t.Fatalf("Test package copy error: %s", err) + } + + m, err := New(tmpDir, "") + if err != nil { + t.Fatalf("moq.New: %s", err) + } + var buf bytes.Buffer + err = m.Mock(&buf, "Foo") + if err != nil { + t.Errorf("m.Mock: %s", err) + } + s := buf.String() + if strings.Contains(s, `github.com/matryer/modules`) { + t.Errorf("should not have cyclic dependency") + } + // assertions of things that should be mentioned + var strs = []string{ + "package simple", + "var _ Foo = &FooMock{}", + "type FooMock struct", + } + for _, str := range strs { + if !strings.Contains(s, str) { + t.Errorf("expected but missing: \"%s\"", str) + } + } +} +func TestModulesNestedPackage(t *testing.T) { + tmpDir, err := copyTestPackage("testpackages/modules") + defer os.RemoveAll(tmpDir) + if err != nil { + t.Fatalf("Test package copy error: %s", err) + } + + m, err := New(tmpDir, "nested") + if err != nil { + t.Fatalf("moq.New: %s", err) + } + var buf bytes.Buffer + err = m.Mock(&buf, "Foo") + if err != nil { + t.Errorf("m.Mock: %s", err) + } + s := buf.String() + // assertions of things that should be mentioned + var strs = []string{ + "package nested", + "github.com/matryer/modules", + "var _ simple.Foo = &FooMock{}", + "type FooMock struct", + "func (mock *FooMock) FooIt(bar *simple.Bar) {", + } + for _, str := range strs { + if !strings.Contains(s, str) { + t.Errorf("expected but missing: \"%s\"", str) + } + } +} diff --git a/pkg/moq/testpackages/modules/go.mod b/pkg/moq/testpackages/modules/go.mod new file mode 100644 index 0000000..1562005 --- /dev/null +++ b/pkg/moq/testpackages/modules/go.mod @@ -0,0 +1 @@ +module github.com/matryer/modules diff --git a/pkg/moq/testpackages/modules/simple.go b/pkg/moq/testpackages/modules/simple.go new file mode 100644 index 0000000..71b4880 --- /dev/null +++ b/pkg/moq/testpackages/modules/simple.go @@ -0,0 +1,9 @@ +package simple + +// Foo is a test interface +type Foo interface { + FooIt(bar *Bar) +} + +// Bar is a test type +type Bar struct{}