fix cyclic dependency with go modules
This commit is contained in:
parent
a838c8bc30
commit
e17abc4d5d
@ -4,10 +4,7 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/format"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"io"
|
||||
"os"
|
||||
@ -64,10 +61,8 @@ var golintInitialisms = []string{
|
||||
|
||||
// Mocker can generate mock structs.
|
||||
type Mocker struct {
|
||||
src string
|
||||
srcPkg *packages.Package
|
||||
tmpl *template.Template
|
||||
fset *token.FileSet
|
||||
pkgs map[string]*ast.Package
|
||||
pkgName string
|
||||
pkgPath string
|
||||
|
||||
@ -76,31 +71,24 @@ type Mocker struct {
|
||||
|
||||
// New makes a new Mocker for the specified package directory.
|
||||
func New(src, packageName string) (*Mocker, error) {
|
||||
fset := token.NewFileSet()
|
||||
noTestFiles := func(i os.FileInfo) bool {
|
||||
return !strings.HasSuffix(i.Name(), "_test.go")
|
||||
}
|
||||
wd, err := os.Getwd()
|
||||
srcPkg, err := pkgInfoFromPath(src, packages.LoadSyntax)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to determin current working directory: %s", err)
|
||||
return nil, fmt.Errorf("Couldn't load source package: %s", err)
|
||||
}
|
||||
packagePath := stripGopath(filepath.Join(wd, src, packageName))
|
||||
pkgPath := srcPkg.PkgPath
|
||||
|
||||
pkgs, err := parser.ParseDir(fset, src, noTestFiles, parser.SpuriousErrors)
|
||||
if len(packageName) == 0 {
|
||||
packageName = srcPkg.Name
|
||||
} else {
|
||||
mockPkgPath := filepath.Join(src, packageName)
|
||||
if _, err := os.Stat(mockPkgPath); os.IsNotExist(err) {
|
||||
os.Mkdir(mockPkgPath, os.ModePerm)
|
||||
}
|
||||
mockPkg, err := pkgInfoFromPath(mockPkgPath, packages.LoadFiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Couldn't load mock package: %s", err)
|
||||
}
|
||||
if len(packageName) == 0 {
|
||||
for pkgName := range pkgs {
|
||||
if strings.Contains(pkgName, "_test") {
|
||||
continue
|
||||
}
|
||||
packageName = pkgName
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(packageName) == 0 {
|
||||
return nil, errors.New("failed to determine package name")
|
||||
pkgPath = mockPkg.PkgPath
|
||||
}
|
||||
|
||||
tmpl, err := template.New("moq").Funcs(templateFuncs).Parse(moqTemplate)
|
||||
@ -108,12 +96,10 @@ func New(src, packageName string) (*Mocker, error) {
|
||||
return nil, err
|
||||
}
|
||||
return &Mocker{
|
||||
src: src,
|
||||
tmpl: tmpl,
|
||||
fset: fset,
|
||||
pkgs: pkgs,
|
||||
srcPkg: srcPkg,
|
||||
pkgName: packageName,
|
||||
pkgPath: packagePath,
|
||||
pkgPath: pkgPath,
|
||||
imports: make(map[string]bool),
|
||||
}, nil
|
||||
}
|
||||
@ -124,11 +110,6 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
|
||||
return errors.New("must specify one interface")
|
||||
}
|
||||
|
||||
pkgInfo, err := pkgInfoFromPath(m.src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
doc := doc{
|
||||
PackageName: m.pkgName,
|
||||
Imports: moqImports,
|
||||
@ -136,7 +117,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
|
||||
|
||||
mocksMethods := false
|
||||
|
||||
tpkg := pkgInfo.Types
|
||||
tpkg := m.srcPkg.Types
|
||||
for _, n := range name {
|
||||
iface := tpkg.Scope().Lookup(n)
|
||||
if iface == nil {
|
||||
@ -177,7 +158,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = m.tmpl.Execute(&buf, doc)
|
||||
err := m.tmpl.Execute(&buf, doc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -228,23 +209,15 @@ func (m *Mocker) extractArgs(sig *types.Signature, list *types.Tuple, nameFormat
|
||||
return params
|
||||
}
|
||||
|
||||
func pkgInfoFromPath(src string) (*packages.Package, error) {
|
||||
abs, err := filepath.Abs(src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pkgFull := stripGopath(abs)
|
||||
|
||||
func pkgInfoFromPath(src string, mode packages.LoadMode) (*packages.Package, error) {
|
||||
conf := packages.Config{
|
||||
Mode: packages.LoadSyntax,
|
||||
Mode: mode,
|
||||
Dir: src,
|
||||
}
|
||||
|
||||
pkgs, err := packages.Load(&conf, pkgFull)
|
||||
pkgs, err := packages.Load(&conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(pkgs) == 0 {
|
||||
return nil, errors.New("No packages found")
|
||||
}
|
||||
|
129
pkg/moq/moq_modules_test.go
Normal file
129
pkg/moq/moq_modules_test.go
Normal file
@ -0,0 +1,129 @@
|
||||
// +build go1.11
|
||||
|
||||
package moq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// copy copies srcPath to destPath, dirs and files
|
||||
func copy(srcPath, destPath string, item os.FileInfo) error {
|
||||
if item.IsDir() {
|
||||
if err := os.MkdirAll(destPath, os.FileMode(0755)); err != nil {
|
||||
return err
|
||||
}
|
||||
items, err := ioutil.ReadDir(srcPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, item := range items {
|
||||
src := filepath.Join(srcPath, item.Name())
|
||||
dest := filepath.Join(destPath, item.Name())
|
||||
if err := copy(src, dest, item); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
src, err := os.Open(srcPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
dest, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dest.Close()
|
||||
|
||||
_, err = io.Copy(dest, src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// copyTestPackage copies test package to a temporary directory
|
||||
func copyTestPackage(srcPath string) (string, error) {
|
||||
tmpDir, err := ioutil.TempDir("", "moq-tests")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
info, err := os.Lstat(srcPath)
|
||||
if err != nil {
|
||||
return tmpDir, err
|
||||
}
|
||||
return tmpDir, copy(srcPath, tmpDir, info)
|
||||
}
|
||||
|
||||
func TestModulesSamePackage(t *testing.T) {
|
||||
tmpDir, err := copyTestPackage("testpackages/modules")
|
||||
defer os.RemoveAll(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Test package copy error: %s", err)
|
||||
}
|
||||
|
||||
m, err := New(tmpDir, "")
|
||||
if err != nil {
|
||||
t.Fatalf("moq.New: %s", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
err = m.Mock(&buf, "Foo")
|
||||
if err != nil {
|
||||
t.Errorf("m.Mock: %s", err)
|
||||
}
|
||||
s := buf.String()
|
||||
if strings.Contains(s, `github.com/matryer/modules`) {
|
||||
t.Errorf("should not have cyclic dependency")
|
||||
}
|
||||
// assertions of things that should be mentioned
|
||||
var strs = []string{
|
||||
"package simple",
|
||||
"var _ Foo = &FooMock{}",
|
||||
"type FooMock struct",
|
||||
}
|
||||
for _, str := range strs {
|
||||
if !strings.Contains(s, str) {
|
||||
t.Errorf("expected but missing: \"%s\"", str)
|
||||
}
|
||||
}
|
||||
}
|
||||
func TestModulesNestedPackage(t *testing.T) {
|
||||
tmpDir, err := copyTestPackage("testpackages/modules")
|
||||
defer os.RemoveAll(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Test package copy error: %s", err)
|
||||
}
|
||||
|
||||
m, err := New(tmpDir, "nested")
|
||||
if err != nil {
|
||||
t.Fatalf("moq.New: %s", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
err = m.Mock(&buf, "Foo")
|
||||
if err != nil {
|
||||
t.Errorf("m.Mock: %s", err)
|
||||
}
|
||||
s := buf.String()
|
||||
// assertions of things that should be mentioned
|
||||
var strs = []string{
|
||||
"package nested",
|
||||
"github.com/matryer/modules",
|
||||
"var _ simple.Foo = &FooMock{}",
|
||||
"type FooMock struct",
|
||||
"func (mock *FooMock) FooIt(bar *simple.Bar) {",
|
||||
}
|
||||
for _, str := range strs {
|
||||
if !strings.Contains(s, str) {
|
||||
t.Errorf("expected but missing: \"%s\"", str)
|
||||
}
|
||||
}
|
||||
}
|
1
pkg/moq/testpackages/modules/go.mod
Normal file
1
pkg/moq/testpackages/modules/go.mod
Normal file
@ -0,0 +1 @@
|
||||
module github.com/matryer/modules
|
7
pkg/moq/testpackages/modules/simple.go
Normal file
7
pkg/moq/testpackages/modules/simple.go
Normal file
@ -0,0 +1,7 @@
|
||||
package simple
|
||||
|
||||
type Foo interface {
|
||||
FooIt(bar *Bar)
|
||||
}
|
||||
|
||||
type Bar struct{}
|
Loading…
Reference in New Issue
Block a user