fixed imports and improved counters

This commit is contained in:
Mat Ryer 2017-07-07 14:38:24 +01:00
parent 3dbdbe86c1
commit ae57d77f27
4 changed files with 71 additions and 13 deletions

View File

@ -21,6 +21,8 @@ type Mocker struct {
fset *token.FileSet fset *token.FileSet
pkgs map[string]*ast.Package pkgs map[string]*ast.Package
pkgName string pkgName string
imports []string
} }
// New makes a new Mocker for the specified package directory. // New makes a new Mocker for the specified package directory.
@ -63,7 +65,10 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
if len(name) == 0 { if len(name) == 0 {
return errors.New("must specify one interface") return errors.New("must specify one interface")
} }
var objs []*obj doc := doc{
PackageName: m.pkgName,
Imports: moqImports,
}
for _, pkg := range m.pkgs { for _, pkg := range m.pkgs {
i := 0 i := 0
files := make([]*ast.File, len(pkg.Files)) files := make([]*ast.File, len(pkg.Files))
@ -85,7 +90,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
return fmt.Errorf("%s (%s) not an interface", n, iface.Type().String()) return fmt.Errorf("%s (%s) not an interface", n, iface.Type().String())
} }
iiface := iface.Type().Underlying().(*types.Interface).Complete() iiface := iface.Type().Underlying().(*types.Interface).Complete()
obj := &obj{ obj := obj{
InterfaceName: n, InterfaceName: n,
} }
for i := 0; i < iiface.NumMethods(); i++ { for i := 0; i < iiface.NumMethods(); i++ {
@ -98,16 +103,11 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
method.Params = m.extractArgs(sig, sig.Params(), "in%d") method.Params = m.extractArgs(sig, sig.Params(), "in%d")
method.Returns = m.extractArgs(sig, sig.Results(), "out%d") method.Returns = m.extractArgs(sig, sig.Results(), "out%d")
} }
objs = append(objs, obj) doc.Objects = append(doc.Objects, obj)
} }
} }
err := m.tmpl.Execute(w, struct { doc.Imports = append(doc.Imports, m.imports...)
PackageName string err := m.tmpl.Execute(w, doc)
Objs []*obj
}{
PackageName: m.pkgName,
Objs: objs,
})
if err != nil { if err != nil {
return err return err
} }
@ -118,6 +118,7 @@ func (m *Mocker) packageQualifier(pkg *types.Package) string {
if m.pkgName == pkg.Name() { if m.pkgName == pkg.Name() {
return "" return ""
} }
m.imports = append(m.imports, pkg.Path())
return pkg.Name() return pkg.Name()
} }
@ -143,6 +144,12 @@ func (m *Mocker) extractArgs(sig *types.Signature, list *types.Tuple, nameFormat
return params return params
} }
type doc struct {
PackageName string
Objects []obj
Imports []string
}
type obj struct { type obj struct {
InterfaceName string InterfaceName string
Methods []*method Methods []*method
@ -204,11 +211,21 @@ func (p param) TypeString() string {
return p.Type return p.Type
} }
// moqImports are the imports all moq files get.
var moqImports = []string{"sync"}
// moqTemplate is the template for mocked code.
var moqTemplate = `package {{.PackageName}} var moqTemplate = `package {{.PackageName}}
// AUTOGENERATED BY MOQ // AUTOGENERATED BY MOQ
// github.com/matryer/moq // github.com/matryer/moq
{{ range $i, $obj := .Objs }}
import (
{{- range .Imports }}
"{{.}}"
{{- end }}
}
{{ range $i, $obj := .Objects }}
// {{.InterfaceName}}Mock is a mock implementation of {{.InterfaceName}}. // {{.InterfaceName}}Mock is a mock implementation of {{.InterfaceName}}.
// //
// func TestSomethingThatUses{{.InterfaceName}}(t *testing.T) { // func TestSomethingThatUses{{.InterfaceName}}(t *testing.T) {
@ -232,9 +249,10 @@ type {{.InterfaceName}}Mock struct {
// CallsTo gets counters for each of the methods indicating // CallsTo gets counters for each of the methods indicating
// how many times each one was called. // how many times each one was called.
CallsTo struct { CallsTo struct {
lock sync.Mutex
{{- range .Methods }} {{- range .Methods }}
// {{ .Name }} holds the number of calls to the {{.Name}} method. // {{ .Name }} holds the number of calls to the {{.Name}} method.
{{ .Name }} uint64 {{ .Name }} int
{{- end }} {{- end }}
} }
} }
@ -244,7 +262,9 @@ func (mock *{{$obj.InterfaceName}}Mock) {{.Name}}({{.Arglist}}) {{.ReturnArglist
if mock.{{.Name}}Func == nil { if mock.{{.Name}}Func == nil {
panic("moq: {{$obj.InterfaceName}}Mock.{{.Name}}Func is nil but was just called") panic("moq: {{$obj.InterfaceName}}Mock.{{.Name}}Func is nil but was just called")
} }
atomic.AddUint64(&mock.CallsTo.{{.Name}}, 1) // count this mock.CallsTo.lock.Lock()
mock.CallsTo.{{.Name}}++
mock.CallsTo.lock.Unlock()
{{- if .ReturnArglist }} {{- if .ReturnArglist }}
return mock.{{.Name}}Func({{.ArgCallList}}) return mock.{{.Name}}Func({{.ArgCallList}})
{{- else }} {{- else }}

View File

@ -2,6 +2,7 @@ package moq
import ( import (
"bytes" "bytes"
"log"
"strings" "strings"
"testing" "testing"
) )
@ -36,6 +37,7 @@ func TestMoq(t *testing.T) {
t.Errorf("expected but missing: \"%s\"", str) t.Errorf("expected but missing: \"%s\"", str)
} }
} }
log.Println(s)
} }
func TestMoqExplicitPackage(t *testing.T) { func TestMoqExplicitPackage(t *testing.T) {
@ -138,3 +140,25 @@ func TestChannelNames(t *testing.T) {
} }
} }
} }
func TestImports(t *testing.T) {
m, err := New("testdata/imports/two", "")
if err != nil {
t.Errorf("moq.New: %s", err)
}
var buf bytes.Buffer
err = m.Mock(&buf, "DoSomething")
if err != nil {
t.Errorf("m.Mock: %s", err)
}
s := buf.String()
var strs = []string{
` "sync"`,
` "github.com/matryer/moq/package/moq/testdata/imports/one"`,
}
for _, str := range strs {
if !strings.Contains(s, str) {
t.Errorf("expected by missing: \"%s\"", str)
}
}
}

View File

@ -0,0 +1,4 @@
package one
// Thing is just a thing.
type Thing struct{}

10
package/moq/testdata/imports/two/two.go vendored Normal file
View File

@ -0,0 +1,10 @@
package two
import (
"github.com/matryer/moq/package/moq/testdata/imports/one"
)
// DoSomething does something.
type DoSomething interface {
Do(thing one.Thing) error
}