fix cyclic dependency with go modules

This commit is contained in:
Ivan Safonov 2019-02-28 23:49:51 +07:00
parent a838c8bc30
commit e17abc4d5d
4 changed files with 158 additions and 48 deletions

View File

@ -4,10 +4,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"go/ast"
"go/format" "go/format"
"go/parser"
"go/token"
"go/types" "go/types"
"io" "io"
"os" "os"
@ -64,10 +61,8 @@ var golintInitialisms = []string{
// Mocker can generate mock structs. // Mocker can generate mock structs.
type Mocker struct { type Mocker struct {
src string srcPkg *packages.Package
tmpl *template.Template tmpl *template.Template
fset *token.FileSet
pkgs map[string]*ast.Package
pkgName string pkgName string
pkgPath string pkgPath string
@ -76,31 +71,24 @@ type Mocker struct {
// New makes a new Mocker for the specified package directory. // New makes a new Mocker for the specified package directory.
func New(src, packageName string) (*Mocker, error) { func New(src, packageName string) (*Mocker, error) {
fset := token.NewFileSet() srcPkg, err := pkgInfoFromPath(src, packages.LoadSyntax)
noTestFiles := func(i os.FileInfo) bool {
return !strings.HasSuffix(i.Name(), "_test.go")
}
wd, err := os.Getwd()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to determin current working directory: %s", err) return nil, fmt.Errorf("Couldn't load source package: %s", err)
} }
packagePath := stripGopath(filepath.Join(wd, src, packageName)) pkgPath := srcPkg.PkgPath
pkgs, err := parser.ParseDir(fset, src, noTestFiles, parser.SpuriousErrors) if len(packageName) == 0 {
packageName = srcPkg.Name
} else {
mockPkgPath := filepath.Join(src, packageName)
if _, err := os.Stat(mockPkgPath); os.IsNotExist(err) {
os.Mkdir(mockPkgPath, os.ModePerm)
}
mockPkg, err := pkgInfoFromPath(mockPkgPath, packages.LoadFiles)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("Couldn't load mock package: %s", err)
} }
if len(packageName) == 0 { pkgPath = mockPkg.PkgPath
for pkgName := range pkgs {
if strings.Contains(pkgName, "_test") {
continue
}
packageName = pkgName
break
}
}
if len(packageName) == 0 {
return nil, errors.New("failed to determine package name")
} }
tmpl, err := template.New("moq").Funcs(templateFuncs).Parse(moqTemplate) tmpl, err := template.New("moq").Funcs(templateFuncs).Parse(moqTemplate)
@ -108,12 +96,10 @@ func New(src, packageName string) (*Mocker, error) {
return nil, err return nil, err
} }
return &Mocker{ return &Mocker{
src: src,
tmpl: tmpl, tmpl: tmpl,
fset: fset, srcPkg: srcPkg,
pkgs: pkgs,
pkgName: packageName, pkgName: packageName,
pkgPath: packagePath, pkgPath: pkgPath,
imports: make(map[string]bool), imports: make(map[string]bool),
}, nil }, nil
} }
@ -124,11 +110,6 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
return errors.New("must specify one interface") return errors.New("must specify one interface")
} }
pkgInfo, err := pkgInfoFromPath(m.src)
if err != nil {
return err
}
doc := doc{ doc := doc{
PackageName: m.pkgName, PackageName: m.pkgName,
Imports: moqImports, Imports: moqImports,
@ -136,7 +117,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
mocksMethods := false mocksMethods := false
tpkg := pkgInfo.Types tpkg := m.srcPkg.Types
for _, n := range name { for _, n := range name {
iface := tpkg.Scope().Lookup(n) iface := tpkg.Scope().Lookup(n)
if iface == nil { if iface == nil {
@ -177,7 +158,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
} }
var buf bytes.Buffer var buf bytes.Buffer
err = m.tmpl.Execute(&buf, doc) err := m.tmpl.Execute(&buf, doc)
if err != nil { if err != nil {
return err return err
} }
@ -228,23 +209,15 @@ func (m *Mocker) extractArgs(sig *types.Signature, list *types.Tuple, nameFormat
return params return params
} }
func pkgInfoFromPath(src string) (*packages.Package, error) { func pkgInfoFromPath(src string, mode packages.LoadMode) (*packages.Package, error) {
abs, err := filepath.Abs(src)
if err != nil {
return nil, err
}
pkgFull := stripGopath(abs)
conf := packages.Config{ conf := packages.Config{
Mode: packages.LoadSyntax, Mode: mode,
Dir: src, Dir: src,
} }
pkgs, err := packages.Load(&conf)
pkgs, err := packages.Load(&conf, pkgFull)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(pkgs) == 0 { if len(pkgs) == 0 {
return nil, errors.New("No packages found") return nil, errors.New("No packages found")
} }

129
pkg/moq/moq_modules_test.go Normal file
View File

@ -0,0 +1,129 @@
// +build go1.11
package moq
import (
"bytes"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
)
// copy copies srcPath to destPath, dirs and files
func copy(srcPath, destPath string, item os.FileInfo) error {
if item.IsDir() {
if err := os.MkdirAll(destPath, os.FileMode(0755)); err != nil {
return err
}
items, err := ioutil.ReadDir(srcPath)
if err != nil {
return err
}
for _, item := range items {
src := filepath.Join(srcPath, item.Name())
dest := filepath.Join(destPath, item.Name())
if err := copy(src, dest, item); err != nil {
return err
}
}
} else {
src, err := os.Open(srcPath)
if err != nil {
return err
}
defer src.Close()
dest, err := os.Create(destPath)
if err != nil {
return err
}
defer dest.Close()
_, err = io.Copy(dest, src)
if err != nil {
return err
}
}
return nil
}
// copyTestPackage copies test package to a temporary directory
func copyTestPackage(srcPath string) (string, error) {
tmpDir, err := ioutil.TempDir("", "moq-tests")
if err != nil {
return "", err
}
info, err := os.Lstat(srcPath)
if err != nil {
return tmpDir, err
}
return tmpDir, copy(srcPath, tmpDir, info)
}
func TestModulesSamePackage(t *testing.T) {
tmpDir, err := copyTestPackage("testpackages/modules")
defer os.RemoveAll(tmpDir)
if err != nil {
t.Fatalf("Test package copy error: %s", err)
}
m, err := New(tmpDir, "")
if err != nil {
t.Fatalf("moq.New: %s", err)
}
var buf bytes.Buffer
err = m.Mock(&buf, "Foo")
if err != nil {
t.Errorf("m.Mock: %s", err)
}
s := buf.String()
if strings.Contains(s, `github.com/matryer/modules`) {
t.Errorf("should not have cyclic dependency")
}
// assertions of things that should be mentioned
var strs = []string{
"package simple",
"var _ Foo = &FooMock{}",
"type FooMock struct",
}
for _, str := range strs {
if !strings.Contains(s, str) {
t.Errorf("expected but missing: \"%s\"", str)
}
}
}
func TestModulesNestedPackage(t *testing.T) {
tmpDir, err := copyTestPackage("testpackages/modules")
defer os.RemoveAll(tmpDir)
if err != nil {
t.Fatalf("Test package copy error: %s", err)
}
m, err := New(tmpDir, "nested")
if err != nil {
t.Fatalf("moq.New: %s", err)
}
var buf bytes.Buffer
err = m.Mock(&buf, "Foo")
if err != nil {
t.Errorf("m.Mock: %s", err)
}
s := buf.String()
// assertions of things that should be mentioned
var strs = []string{
"package nested",
"github.com/matryer/modules",
"var _ simple.Foo = &FooMock{}",
"type FooMock struct",
"func (mock *FooMock) FooIt(bar *simple.Bar) {",
}
for _, str := range strs {
if !strings.Contains(s, str) {
t.Errorf("expected but missing: \"%s\"", str)
}
}
}

View File

@ -0,0 +1 @@
module github.com/matryer/modules

View File

@ -0,0 +1,7 @@
package simple
type Foo interface {
FooIt(bar *Bar)
}
type Bar struct{}