fix cyclic dependency with go modules
This commit is contained in:
parent
a838c8bc30
commit
e17abc4d5d
@ -4,10 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"go/ast"
|
|
||||||
"go/format"
|
"go/format"
|
||||||
"go/parser"
|
|
||||||
"go/token"
|
|
||||||
"go/types"
|
"go/types"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
@ -64,10 +61,8 @@ var golintInitialisms = []string{
|
|||||||
|
|
||||||
// Mocker can generate mock structs.
|
// Mocker can generate mock structs.
|
||||||
type Mocker struct {
|
type Mocker struct {
|
||||||
src string
|
srcPkg *packages.Package
|
||||||
tmpl *template.Template
|
tmpl *template.Template
|
||||||
fset *token.FileSet
|
|
||||||
pkgs map[string]*ast.Package
|
|
||||||
pkgName string
|
pkgName string
|
||||||
pkgPath string
|
pkgPath string
|
||||||
|
|
||||||
@ -76,31 +71,24 @@ type Mocker struct {
|
|||||||
|
|
||||||
// New makes a new Mocker for the specified package directory.
|
// New makes a new Mocker for the specified package directory.
|
||||||
func New(src, packageName string) (*Mocker, error) {
|
func New(src, packageName string) (*Mocker, error) {
|
||||||
fset := token.NewFileSet()
|
srcPkg, err := pkgInfoFromPath(src, packages.LoadSyntax)
|
||||||
noTestFiles := func(i os.FileInfo) bool {
|
|
||||||
return !strings.HasSuffix(i.Name(), "_test.go")
|
|
||||||
}
|
|
||||||
wd, err := os.Getwd()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to determin current working directory: %s", err)
|
return nil, fmt.Errorf("Couldn't load source package: %s", err)
|
||||||
}
|
}
|
||||||
packagePath := stripGopath(filepath.Join(wd, src, packageName))
|
pkgPath := srcPkg.PkgPath
|
||||||
|
|
||||||
pkgs, err := parser.ParseDir(fset, src, noTestFiles, parser.SpuriousErrors)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(packageName) == 0 {
|
if len(packageName) == 0 {
|
||||||
for pkgName := range pkgs {
|
packageName = srcPkg.Name
|
||||||
if strings.Contains(pkgName, "_test") {
|
} else {
|
||||||
continue
|
mockPkgPath := filepath.Join(src, packageName)
|
||||||
}
|
if _, err := os.Stat(mockPkgPath); os.IsNotExist(err) {
|
||||||
packageName = pkgName
|
os.Mkdir(mockPkgPath, os.ModePerm)
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
mockPkg, err := pkgInfoFromPath(mockPkgPath, packages.LoadFiles)
|
||||||
if len(packageName) == 0 {
|
if err != nil {
|
||||||
return nil, errors.New("failed to determine package name")
|
return nil, fmt.Errorf("Couldn't load mock package: %s", err)
|
||||||
|
}
|
||||||
|
pkgPath = mockPkg.PkgPath
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpl, err := template.New("moq").Funcs(templateFuncs).Parse(moqTemplate)
|
tmpl, err := template.New("moq").Funcs(templateFuncs).Parse(moqTemplate)
|
||||||
@ -108,12 +96,10 @@ func New(src, packageName string) (*Mocker, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Mocker{
|
return &Mocker{
|
||||||
src: src,
|
|
||||||
tmpl: tmpl,
|
tmpl: tmpl,
|
||||||
fset: fset,
|
srcPkg: srcPkg,
|
||||||
pkgs: pkgs,
|
|
||||||
pkgName: packageName,
|
pkgName: packageName,
|
||||||
pkgPath: packagePath,
|
pkgPath: pkgPath,
|
||||||
imports: make(map[string]bool),
|
imports: make(map[string]bool),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@ -124,11 +110,6 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
|
|||||||
return errors.New("must specify one interface")
|
return errors.New("must specify one interface")
|
||||||
}
|
}
|
||||||
|
|
||||||
pkgInfo, err := pkgInfoFromPath(m.src)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
doc := doc{
|
doc := doc{
|
||||||
PackageName: m.pkgName,
|
PackageName: m.pkgName,
|
||||||
Imports: moqImports,
|
Imports: moqImports,
|
||||||
@ -136,7 +117,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
|
|||||||
|
|
||||||
mocksMethods := false
|
mocksMethods := false
|
||||||
|
|
||||||
tpkg := pkgInfo.Types
|
tpkg := m.srcPkg.Types
|
||||||
for _, n := range name {
|
for _, n := range name {
|
||||||
iface := tpkg.Scope().Lookup(n)
|
iface := tpkg.Scope().Lookup(n)
|
||||||
if iface == nil {
|
if iface == nil {
|
||||||
@ -177,7 +158,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
err = m.tmpl.Execute(&buf, doc)
|
err := m.tmpl.Execute(&buf, doc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -228,23 +209,15 @@ func (m *Mocker) extractArgs(sig *types.Signature, list *types.Tuple, nameFormat
|
|||||||
return params
|
return params
|
||||||
}
|
}
|
||||||
|
|
||||||
func pkgInfoFromPath(src string) (*packages.Package, error) {
|
func pkgInfoFromPath(src string, mode packages.LoadMode) (*packages.Package, error) {
|
||||||
abs, err := filepath.Abs(src)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
pkgFull := stripGopath(abs)
|
|
||||||
|
|
||||||
conf := packages.Config{
|
conf := packages.Config{
|
||||||
Mode: packages.LoadSyntax,
|
Mode: mode,
|
||||||
Dir: src,
|
Dir: src,
|
||||||
}
|
}
|
||||||
|
pkgs, err := packages.Load(&conf)
|
||||||
pkgs, err := packages.Load(&conf, pkgFull)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(pkgs) == 0 {
|
if len(pkgs) == 0 {
|
||||||
return nil, errors.New("No packages found")
|
return nil, errors.New("No packages found")
|
||||||
}
|
}
|
||||||
|
129
pkg/moq/moq_modules_test.go
Normal file
129
pkg/moq/moq_modules_test.go
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
// +build go1.11
|
||||||
|
|
||||||
|
package moq
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// copy copies srcPath to destPath, dirs and files
|
||||||
|
func copy(srcPath, destPath string, item os.FileInfo) error {
|
||||||
|
if item.IsDir() {
|
||||||
|
if err := os.MkdirAll(destPath, os.FileMode(0755)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
items, err := ioutil.ReadDir(srcPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, item := range items {
|
||||||
|
src := filepath.Join(srcPath, item.Name())
|
||||||
|
dest := filepath.Join(destPath, item.Name())
|
||||||
|
if err := copy(src, dest, item); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
src, err := os.Open(srcPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer src.Close()
|
||||||
|
|
||||||
|
dest, err := os.Create(destPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer dest.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(dest, src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyTestPackage copies test package to a temporary directory
|
||||||
|
func copyTestPackage(srcPath string) (string, error) {
|
||||||
|
tmpDir, err := ioutil.TempDir("", "moq-tests")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := os.Lstat(srcPath)
|
||||||
|
if err != nil {
|
||||||
|
return tmpDir, err
|
||||||
|
}
|
||||||
|
return tmpDir, copy(srcPath, tmpDir, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModulesSamePackage(t *testing.T) {
|
||||||
|
tmpDir, err := copyTestPackage("testpackages/modules")
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Test package copy error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := New(tmpDir, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("moq.New: %s", err)
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err = m.Mock(&buf, "Foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("m.Mock: %s", err)
|
||||||
|
}
|
||||||
|
s := buf.String()
|
||||||
|
if strings.Contains(s, `github.com/matryer/modules`) {
|
||||||
|
t.Errorf("should not have cyclic dependency")
|
||||||
|
}
|
||||||
|
// assertions of things that should be mentioned
|
||||||
|
var strs = []string{
|
||||||
|
"package simple",
|
||||||
|
"var _ Foo = &FooMock{}",
|
||||||
|
"type FooMock struct",
|
||||||
|
}
|
||||||
|
for _, str := range strs {
|
||||||
|
if !strings.Contains(s, str) {
|
||||||
|
t.Errorf("expected but missing: \"%s\"", str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func TestModulesNestedPackage(t *testing.T) {
|
||||||
|
tmpDir, err := copyTestPackage("testpackages/modules")
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Test package copy error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := New(tmpDir, "nested")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("moq.New: %s", err)
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err = m.Mock(&buf, "Foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("m.Mock: %s", err)
|
||||||
|
}
|
||||||
|
s := buf.String()
|
||||||
|
// assertions of things that should be mentioned
|
||||||
|
var strs = []string{
|
||||||
|
"package nested",
|
||||||
|
"github.com/matryer/modules",
|
||||||
|
"var _ simple.Foo = &FooMock{}",
|
||||||
|
"type FooMock struct",
|
||||||
|
"func (mock *FooMock) FooIt(bar *simple.Bar) {",
|
||||||
|
}
|
||||||
|
for _, str := range strs {
|
||||||
|
if !strings.Contains(s, str) {
|
||||||
|
t.Errorf("expected but missing: \"%s\"", str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
1
pkg/moq/testpackages/modules/go.mod
Normal file
1
pkg/moq/testpackages/modules/go.mod
Normal file
@ -0,0 +1 @@
|
|||||||
|
module github.com/matryer/modules
|
7
pkg/moq/testpackages/modules/simple.go
Normal file
7
pkg/moq/testpackages/modules/simple.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package simple
|
||||||
|
|
||||||
|
type Foo interface {
|
||||||
|
FooIt(bar *Bar)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Bar struct{}
|
Loading…
Reference in New Issue
Block a user