diff --git a/package/moq/moq.go b/package/moq/moq.go index b647076..34ac95b 100644 --- a/package/moq/moq.go +++ b/package/moq/moq.go @@ -21,6 +21,8 @@ type Mocker struct { fset *token.FileSet pkgs map[string]*ast.Package pkgName string + + imports []string } // New makes a new Mocker for the specified package directory. @@ -63,7 +65,10 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { if len(name) == 0 { return errors.New("must specify one interface") } - var objs []*obj + doc := doc{ + PackageName: m.pkgName, + Imports: moqImports, + } for _, pkg := range m.pkgs { i := 0 files := make([]*ast.File, len(pkg.Files)) @@ -85,7 +90,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { return fmt.Errorf("%s (%s) not an interface", n, iface.Type().String()) } iiface := iface.Type().Underlying().(*types.Interface).Complete() - obj := &obj{ + obj := obj{ InterfaceName: n, } for i := 0; i < iiface.NumMethods(); i++ { @@ -98,16 +103,11 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { method.Params = m.extractArgs(sig, sig.Params(), "in%d") method.Returns = m.extractArgs(sig, sig.Results(), "out%d") } - objs = append(objs, obj) + doc.Objects = append(doc.Objects, obj) } } - err := m.tmpl.Execute(w, struct { - PackageName string - Objs []*obj - }{ - PackageName: m.pkgName, - Objs: objs, - }) + doc.Imports = append(doc.Imports, m.imports...) + err := m.tmpl.Execute(w, doc) if err != nil { return err } @@ -118,6 +118,7 @@ func (m *Mocker) packageQualifier(pkg *types.Package) string { if m.pkgName == pkg.Name() { return "" } + m.imports = append(m.imports, pkg.Path()) return pkg.Name() } @@ -143,6 +144,12 @@ func (m *Mocker) extractArgs(sig *types.Signature, list *types.Tuple, nameFormat return params } +type doc struct { + PackageName string + Objects []obj + Imports []string +} + type obj struct { InterfaceName string Methods []*method @@ -204,11 +211,21 @@ func (p param) TypeString() string { return p.Type } +// moqImports are the imports all moq files get. +var moqImports = []string{"sync"} + +// moqTemplate is the template for mocked code. var moqTemplate = `package {{.PackageName}} // AUTOGENERATED BY MOQ // github.com/matryer/moq -{{ range $i, $obj := .Objs }} + +import ( +{{- range .Imports }} + "{{.}}" +{{- end }} +} +{{ range $i, $obj := .Objects }} // {{.InterfaceName}}Mock is a mock implementation of {{.InterfaceName}}. // // func TestSomethingThatUses{{.InterfaceName}}(t *testing.T) { @@ -232,9 +249,10 @@ type {{.InterfaceName}}Mock struct { // CallsTo gets counters for each of the methods indicating // how many times each one was called. CallsTo struct { + lock sync.Mutex {{- range .Methods }} // {{ .Name }} holds the number of calls to the {{.Name}} method. - {{ .Name }} uint64 + {{ .Name }} int {{- end }} } } @@ -244,7 +262,9 @@ func (mock *{{$obj.InterfaceName}}Mock) {{.Name}}({{.Arglist}}) {{.ReturnArglist if mock.{{.Name}}Func == nil { panic("moq: {{$obj.InterfaceName}}Mock.{{.Name}}Func is nil but was just called") } - atomic.AddUint64(&mock.CallsTo.{{.Name}}, 1) // count this + mock.CallsTo.lock.Lock() + mock.CallsTo.{{.Name}}++ + mock.CallsTo.lock.Unlock() {{- if .ReturnArglist }} return mock.{{.Name}}Func({{.ArgCallList}}) {{- else }} diff --git a/package/moq/moq_test.go b/package/moq/moq_test.go index 36ff5b2..2c86670 100644 --- a/package/moq/moq_test.go +++ b/package/moq/moq_test.go @@ -2,6 +2,7 @@ package moq import ( "bytes" + "log" "strings" "testing" ) @@ -36,6 +37,7 @@ func TestMoq(t *testing.T) { t.Errorf("expected but missing: \"%s\"", str) } } + log.Println(s) } func TestMoqExplicitPackage(t *testing.T) { @@ -138,3 +140,25 @@ func TestChannelNames(t *testing.T) { } } } + +func TestImports(t *testing.T) { + m, err := New("testdata/imports/two", "") + if err != nil { + t.Errorf("moq.New: %s", err) + } + var buf bytes.Buffer + err = m.Mock(&buf, "DoSomething") + if err != nil { + t.Errorf("m.Mock: %s", err) + } + s := buf.String() + var strs = []string{ + ` "sync"`, + ` "github.com/matryer/moq/package/moq/testdata/imports/one"`, + } + for _, str := range strs { + if !strings.Contains(s, str) { + t.Errorf("expected by missing: \"%s\"", str) + } + } +} diff --git a/package/moq/testdata/imports/one/one.go b/package/moq/testdata/imports/one/one.go new file mode 100644 index 0000000..c38a014 --- /dev/null +++ b/package/moq/testdata/imports/one/one.go @@ -0,0 +1,4 @@ +package one + +// Thing is just a thing. +type Thing struct{} diff --git a/package/moq/testdata/imports/two/two.go b/package/moq/testdata/imports/two/two.go new file mode 100644 index 0000000..66d5ca2 --- /dev/null +++ b/package/moq/testdata/imports/two/two.go @@ -0,0 +1,10 @@ +package two + +import ( + "github.com/matryer/moq/package/moq/testdata/imports/one" +) + +// DoSomething does something. +type DoSomething interface { + Do(thing one.Thing) error +}