diff --git a/.travis.yml b/.travis.yml index 78c3c72..9b3ca76 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,8 @@ language: go sudo: false go: - - 1.7.1 + - 1.7.x + - 1.8.x - tip before_install: diff --git a/README.md b/README.md index 36fde64..b0f06af 100644 --- a/README.md +++ b/README.md @@ -63,12 +63,10 @@ This this example, Moq generated the `EmailSenderMock` type: ```go func TestCompleteSignup(t *testing.T) { - called := false var sentTo string mockedEmailSender = &EmailSenderMock{ SendFunc: func(to, subject, body string) error { - called = true sentTo = to return nil }, @@ -76,8 +74,8 @@ func TestCompleteSignup(t *testing.T) { CompleteSignUp("me@email.com", mockedEmailSender) - if called == false { - t.Error("Sender.Send expected") + if mockedEmailSender.CallsTo.Send != 1 { + t.Errorf("Send was called %d times", mockedEmailSender.CallsTo.Send) } if sentTo != "me@email.com" { t.Errorf("unexpected recipient: %s", sentTo) @@ -95,8 +93,11 @@ The mocked structure implements the interface, where each method calls the assoc ## Tips * Keep mocked logic inside the test that is using it -* Only mock the fields you need - it will panic if a nil function gets called +* Only mock the fields you need +* It will panic if a nil function gets called +* Name arguments in the interface for a better experience * Use closured variables inside your test function to capture details about the calls to the methods +* Use `.CallsTo.Method` to track the calls * Use `go:generate` to invoke the `moq` command ## License diff --git a/package/moq/importer.go b/package/moq/importer.go new file mode 100644 index 0000000..66c6109 --- /dev/null +++ b/package/moq/importer.go @@ -0,0 +1,145 @@ +package moq + +// taken from https://github.com/ernesto-jimenez/gogen +// Copyright (c) 2015 Ernesto Jiménez + +import ( + "fmt" + "go/ast" + "go/importer" + "go/parser" + "go/token" + "go/types" + "io/ioutil" + "os" + "path" + "path/filepath" + "strings" +) + +type customImporter struct { + imported map[string]*types.Package + base types.Importer + skipTestFiles bool +} + +func (i *customImporter) Import(path string) (*types.Package, error) { + var err error + if path == "" || path[0] == '.' { + path, err = filepath.Abs(filepath.Clean(path)) + if err != nil { + return nil, err + } + path = stripGopath(path) + } + if pkg, ok := i.imported[path]; ok { + return pkg, nil + } + pkg, err := i.fsPkg(path) + if err != nil { + return nil, err + } + i.imported[path] = pkg + return pkg, nil +} + +func gopathDir(pkg string) (string, error) { + for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") { + absPath, err := filepath.Abs(path.Join(gopath, "src", pkg)) + if err != nil { + return "", err + } + if dir, err := os.Stat(absPath); err == nil && dir.IsDir() { + return absPath, nil + } + } + return "", fmt.Errorf("%s not in $GOPATH", pkg) +} + +func removeGopath(p string) string { + for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") { + p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1) + } + return p +} + +func (i *customImporter) fsPkg(pkg string) (*types.Package, error) { + dir, err := gopathDir(pkg) + if err != nil { + return importOrErr(i.base, pkg, err) + } + + dirFiles, err := ioutil.ReadDir(dir) + if err != nil { + return importOrErr(i.base, pkg, err) + } + + fset := token.NewFileSet() + var files []*ast.File + for _, fileInfo := range dirFiles { + if fileInfo.IsDir() { + continue + } + n := fileInfo.Name() + if path.Ext(fileInfo.Name()) != ".go" { + continue + } + if i.skipTestFiles && strings.Contains(fileInfo.Name(), "_test.go") { + continue + } + file := path.Join(dir, n) + src, err := ioutil.ReadFile(file) + if err != nil { + return nil, err + } + f, err := parser.ParseFile(fset, file, src, 0) + if err != nil { + return nil, err + } + files = append(files, f) + } + conf := types.Config{ + Importer: i, + } + p, err := conf.Check(pkg, fset, files, nil) + + if err != nil { + return importOrErr(i.base, pkg, err) + } + return p, nil +} + +func importOrErr(base types.Importer, pkg string, err error) (*types.Package, error) { + p, impErr := base.Import(pkg) + if impErr != nil { + return nil, err + } + return p, nil +} + +// newImporter returns an importer that will try to import code from gopath before using go/importer.Default and skipping test files +func newImporter() types.Importer { + return &customImporter{ + imported: make(map[string]*types.Package), + base: importer.Default(), + skipTestFiles: true, + } +} + +// // DefaultWithTestFiles same as Default but it parses test files too +// func DefaultWithTestFiles() types.Importer { +// return &customImporter{ +// imported: make(map[string]*types.Package), +// base: importer.Default(), +// skipTestFiles: false, +// } +// } + +// stripGopath teks the directory to a package and remove the gopath to get the +// cannonical package name +func stripGopath(p string) string { + for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") { + p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1) + } + return p +} diff --git a/package/moq/moq.go b/package/moq/moq.go index 0eff431..be3af3d 100644 --- a/package/moq/moq.go +++ b/package/moq/moq.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "go/ast" - "go/importer" "go/parser" "go/token" "go/types" @@ -21,6 +20,8 @@ type Mocker struct { fset *token.FileSet pkgs map[string]*ast.Package pkgName string + + imports map[string]bool } // New makes a new Mocker for the specified package directory. @@ -45,7 +46,7 @@ func New(src, packageName string) (*Mocker, error) { if len(packageName) == 0 { return nil, errors.New("failed to determine package name") } - tmpl, err := template.New("moq").Parse(moqTemplate) + tmpl, err := template.New("moq").Funcs(templateFuncs).Parse(moqTemplate) if err != nil { return nil, err } @@ -55,6 +56,7 @@ func New(src, packageName string) (*Mocker, error) { fset: fset, pkgs: pkgs, pkgName: packageName, + imports: make(map[string]bool), }, nil } @@ -63,7 +65,10 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { if len(name) == 0 { return errors.New("must specify one interface") } - var objs []*obj + doc := doc{ + PackageName: m.pkgName, + Imports: moqImports, + } for _, pkg := range m.pkgs { i := 0 files := make([]*ast.File, len(pkg.Files)) @@ -71,7 +76,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { files[i] = file i++ } - conf := types.Config{Importer: importer.Default()} + conf := types.Config{Importer: newImporter()} tpkg, err := conf.Check(m.src, m.fset, files, nil) if err != nil { return err @@ -85,7 +90,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { return fmt.Errorf("%s (%s) not an interface", n, iface.Type().String()) } iiface := iface.Type().Underlying().(*types.Interface).Complete() - obj := &obj{ + obj := obj{ InterfaceName: n, } for i := 0; i < iiface.NumMethods(); i++ { @@ -98,16 +103,13 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { method.Params = m.extractArgs(sig, sig.Params(), "in%d") method.Returns = m.extractArgs(sig, sig.Results(), "out%d") } - objs = append(objs, obj) + doc.Objects = append(doc.Objects, obj) } } - err := m.tmpl.Execute(w, struct { - PackageName string - Objs []*obj - }{ - PackageName: m.pkgName, - Objs: objs, - }) + for pkgToImport := range m.imports { + doc.Imports = append(doc.Imports, pkgToImport) + } + err := m.tmpl.Execute(w, doc) if err != nil { return err } @@ -118,6 +120,7 @@ func (m *Mocker) packageQualifier(pkg *types.Package) string { if m.pkgName == pkg.Name() { return "" } + m.imports[pkg.Path()] = true return pkg.Name() } @@ -143,6 +146,12 @@ func (m *Mocker) extractArgs(sig *types.Signature, list *types.Tuple, nameFormat return params } +type doc struct { + PackageName string + Objects []obj + Imports []string +} + type obj struct { InterfaceName string Methods []*method @@ -204,11 +213,30 @@ func (p param) TypeString() string { return p.Type } +var templateFuncs = template.FuncMap{ + "Exported": func(s string) string { + if s == "" { + return "" + } + return strings.ToUpper(s[0:1]) + s[1:] + }, +} + +// moqImports are the imports all moq files get. +var moqImports = []string{"sync"} + +// moqTemplate is the template for mocked code. var moqTemplate = `package {{.PackageName}} // AUTOGENERATED BY MOQ // github.com/matryer/moq -{{ range $i, $obj := .Objs }} + +import ( +{{- range .Imports }} + "{{.}}" +{{- end }} +) +{{ range $i, $obj := .Objects }} // {{.InterfaceName}}Mock is a mock implementation of {{.InterfaceName}}. // // func TestSomethingThatUses{{.InterfaceName}}(t *testing.T) { @@ -221,13 +249,38 @@ var moqTemplate = `package {{.PackageName}} // } // // // TODO: use mocked{{.InterfaceName}} in code that requires {{.InterfaceName}} -// +// // and then make assertions. +// // +// // Use the CallsTo structure to access details about what calls were made: +// // +// // if len(mocked{{.InterfaceName}}.CallsTo.MethodFunc) != 1 { +// // t.Errorf("expected 1 call there were %d", len(mocked{{.InterfaceName}}.CallsTo.MethodFunc)) +// // } +// // } type {{.InterfaceName}}Mock struct { {{- range .Methods }} // {{.Name}}Func mocks the {{.Name}} method. {{.Name}}Func func({{ .Arglist }}) {{.ReturnArglist}} +{{ end }} + // CallsTo tracks calls to the methods. + CallsTo struct { + // Enabled indicates that calls will be tracked. + // + // // don't track calls + // {{.InterfaceName}}Mock.CallsTo.Enabled = false + Enabled bool +{{ range .Methods }} + lock{{.Name}} sync.Mutex // protects {{ .Name }} + // {{ .Name }} holds details about calls to the {{.Name}} method. + {{ .Name }} []struct { + {{- range .Params }} + // {{ .Name | Exported }} is the {{ .Name }} argument value. + {{ .Name | Exported }} {{ .Type }} + {{- end }} + } {{- end }} + } } {{ range .Methods }} // {{.Name}} calls {{.Name}}Func. @@ -235,6 +288,19 @@ func (mock *{{$obj.InterfaceName}}Mock) {{.Name}}({{.Arglist}}) {{.ReturnArglist if mock.{{.Name}}Func == nil { panic("moq: {{$obj.InterfaceName}}Mock.{{.Name}}Func is nil but was just called") } + if mock.CallsTo.Enabled { + mock.CallsTo.lock{{.Name}}.Lock() + mock.CallsTo.{{.Name}} = append(mock.CallsTo.{{.Name}}, struct{ + {{- range .Params }} + {{ .Name | Exported }} {{ .Type }} + {{- end }} + }{ + {{- range .Params }} + {{ .Name | Exported }}: {{ .Name }}, + {{- end }} + }) + mock.CallsTo.lock{{.Name}}.Unlock() + } {{- if .ReturnArglist }} return mock.{{.Name}}Func({{.ArgCallList}}) {{- else }} diff --git a/package/moq/moq_test.go b/package/moq/moq_test.go index 0372756..1fad8df 100644 --- a/package/moq/moq_test.go +++ b/package/moq/moq_test.go @@ -9,7 +9,7 @@ import ( func TestMoq(t *testing.T) { m, err := New("testdata/example", "") if err != nil { - t.Errorf("moq.New: %s", err) + t.Fatalf("moq.New: %s", err) } var buf bytes.Buffer err = m.Mock(&buf, "PersonStore") @@ -27,18 +27,22 @@ func TestMoq(t *testing.T) { "func (mock *PersonStoreMock) Get(ctx context.Context, id string) (*Person, error)", "panic(\"moq: PersonStoreMock.CreateFunc is nil but was just called\")", "panic(\"moq: PersonStoreMock.GetFunc is nil but was just called\")", + "mock.CallsTo.lockGet.Lock()", + "mock.CallsTo.Get = append(mock.CallsTo.Get, struct{", + "mock.CallsTo.lockGet.Unlock()", } for _, str := range strs { if !strings.Contains(s, str) { t.Errorf("expected but missing: \"%s\"", str) } } + } func TestMoqExplicitPackage(t *testing.T) { m, err := New("testdata/example", "different") if err != nil { - t.Errorf("moq.New: %s", err) + t.Fatalf("moq.New: %s", err) } var buf bytes.Buffer err = m.Mock(&buf, "PersonStore") @@ -68,7 +72,7 @@ func TestMoqExplicitPackage(t *testing.T) { func TestVariadicArguments(t *testing.T) { m, err := New("testdata/variadic", "") if err != nil { - t.Errorf("moq.New: %s", err) + t.Fatalf("moq.New: %s", err) } var buf bytes.Buffer err = m.Mock(&buf, "Greeter") @@ -93,7 +97,7 @@ func TestVariadicArguments(t *testing.T) { func TestNothingToReturn(t *testing.T) { m, err := New("testdata/example", "") if err != nil { - t.Errorf("moq.New: %s", err) + t.Fatalf("moq.New: %s", err) } var buf bytes.Buffer err = m.Mock(&buf, "PersonStore") @@ -118,7 +122,7 @@ func TestNothingToReturn(t *testing.T) { func TestChannelNames(t *testing.T) { m, err := New("testdata/channels", "") if err != nil { - t.Errorf("moq.New: %s", err) + t.Fatalf("moq.New: %s", err) } var buf bytes.Buffer err = m.Mock(&buf, "Queuer") @@ -131,7 +135,39 @@ func TestChannelNames(t *testing.T) { } for _, str := range strs { if !strings.Contains(s, str) { - t.Errorf("expected by missing: \"%s\"", str) + t.Errorf("expected but missing: \"%s\"", str) } } } + +func TestImports(t *testing.T) { + m, err := New("testdata/imports/two", "") + if err != nil { + t.Fatalf("moq.New: %s", err) + } + var buf bytes.Buffer + err = m.Mock(&buf, "DoSomething") + if err != nil { + t.Errorf("m.Mock: %s", err) + } + s := buf.String() + var strs = []string{ + ` "sync"`, + ` "github.com/matryer/moq/package/moq/testdata/imports/one"`, + } + for _, str := range strs { + if !strings.Contains(s, str) { + t.Errorf("expected but missing: \"%s\"", str) + } + if len(strings.Split(s, str)) > 2 { + t.Errorf("more than one: \"%s\"", str) + } + } +} + +func TestTemplateFuncs(t *testing.T) { + fn := templateFuncs["Exported"].(func(string) string) + if fn("var") != "Var" { + t.Errorf("exported didn't work: %s", fn("var")) + } +} diff --git a/package/moq/testdata/imports/one/one.go b/package/moq/testdata/imports/one/one.go new file mode 100644 index 0000000..c38a014 --- /dev/null +++ b/package/moq/testdata/imports/one/one.go @@ -0,0 +1,4 @@ +package one + +// Thing is just a thing. +type Thing struct{} diff --git a/package/moq/testdata/imports/two/two.go b/package/moq/testdata/imports/two/two.go new file mode 100644 index 0000000..a7635a6 --- /dev/null +++ b/package/moq/testdata/imports/two/two.go @@ -0,0 +1,11 @@ +package two + +import ( + "github.com/matryer/moq/package/moq/testdata/imports/one" +) + +// DoSomething does something. +type DoSomething interface { + Do(thing one.Thing) error + Another(thing one.Thing) error +} diff --git a/preview.png b/preview.png index dfa54f7..73a242f 100644 Binary files a/preview.png and b/preview.png differ