fixed #3 - correct type names

This commit is contained in:
Mat Ryer 2016-09-21 21:39:26 +01:00
parent 86b99f5859
commit 37723c8bd4
3 changed files with 101 additions and 24 deletions

12
main.go
View File

@ -1,11 +1,9 @@
package main package main
import (
"github.com/matryer/moq/package/moq"
)
func main() { func main() {
// var (
m := moq.New() // )
// out := os.Stdout
// m := moq.New(".")
// m.Mock(out, os.Args...)
} }

View File

@ -1,6 +1,7 @@
package moq package moq
import ( import (
"errors"
"fmt" "fmt"
"go/ast" "go/ast"
"go/importer" "go/importer"
@ -19,10 +20,11 @@ type Mocker struct {
tmpl *template.Template tmpl *template.Template
fset *token.FileSet fset *token.FileSet
pkgs map[string]*ast.Package pkgs map[string]*ast.Package
pkgName string
} }
// New makes a new Mocker for the specified package directory. // New makes a new Mocker for the specified package directory.
func New(src string) (*Mocker, error) { func New(src, packageName string) (*Mocker, error) {
fset := token.NewFileSet() fset := token.NewFileSet()
noTestFiles := func(i os.FileInfo) bool { noTestFiles := func(i os.FileInfo) bool {
return !strings.HasSuffix(i.Name(), "_test.go") return !strings.HasSuffix(i.Name(), "_test.go")
@ -31,6 +33,18 @@ func New(src string) (*Mocker, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(packageName) == 0 {
for pkgName := range pkgs {
if strings.Contains(pkgName, "_test") {
continue
}
packageName = pkgName
break
}
}
if len(packageName) == 0 {
return nil, errors.New("failed to determine package name")
}
tmpl, err := template.New("moq").Parse(moqTemplate) tmpl, err := template.New("moq").Parse(moqTemplate)
if err != nil { if err != nil {
return nil, err return nil, err
@ -40,6 +54,7 @@ func New(src string) (*Mocker, error) {
tmpl: tmpl, tmpl: tmpl,
fset: fset, fset: fset,
pkgs: pkgs, pkgs: pkgs,
pkgName: packageName,
}, nil }, nil
} }
@ -74,20 +89,33 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
Name: meth.Name(), Name: meth.Name(),
} }
obj.Methods = append(obj.Methods, method) obj.Methods = append(obj.Methods, method)
method.Params = extractArgs(sig.Params(), "in%d") method.Params = m.extractArgs(sig.Params(), "in%d")
method.Returns = extractArgs(sig.Results(), "out%d") method.Returns = m.extractArgs(sig.Results(), "out%d")
} }
objs = append(objs, obj) objs = append(objs, obj)
} }
} }
err := m.tmpl.Execute(w, struct{ Objs []*obj }{Objs: objs}) err := m.tmpl.Execute(w, struct {
PackageName string
Objs []*obj
}{
PackageName: m.pkgName,
Objs: objs,
})
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func extractArgs(list *types.Tuple, nameFormat string) []*param { func (m *Mocker) packageQualifier(pkg *types.Package) string {
if m.pkgName == pkg.Name() {
return ""
}
return pkg.Name()
}
func (m *Mocker) extractArgs(list *types.Tuple, nameFormat string) []*param {
var params []*param var params []*param
for ii := 0; ii < list.Len(); ii++ { for ii := 0; ii < list.Len(); ii++ {
p := list.At(ii) p := list.At(ii)
@ -95,9 +123,10 @@ func extractArgs(list *types.Tuple, nameFormat string) []*param {
if name == "" { if name == "" {
name = fmt.Sprintf(nameFormat, ii+1) name = fmt.Sprintf(nameFormat, ii+1)
} }
typename := types.TypeString(p.Type(), m.packageQualifier)
param := &param{ param := &param{
Name: name, Name: name,
Type: p.Type().String(), Type: typename,
} }
params = append(params, param) params = append(params, param)
} }
@ -155,7 +184,7 @@ func (p param) TypeString() string {
} }
var moqTemplate = ` var moqTemplate = `
package todo package {{.PackageName}}
// AUTOGENERATED BY MOQ // AUTOGENERATED BY MOQ
// github.com/matryer/moq // github.com/matryer/moq

View File

@ -3,12 +3,13 @@ package moq
import ( import (
"bytes" "bytes"
"log" "log"
"strings"
"testing" "testing"
) )
func TestMoq(t *testing.T) { func TestMoq(t *testing.T) {
m, err := New("../../example") m, err := New("../../example", "")
if err != nil { if err != nil {
t.Errorf("moq.New: %s", err) t.Errorf("moq.New: %s", err)
} }
@ -17,6 +18,55 @@ func TestMoq(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("m.Mock: %s", err) t.Errorf("m.Mock: %s", err)
} }
log.Println(buf.String()) s := buf.String()
log.Println(s)
// assertions of things that should be mentioned
var strs = []string{
"package example",
"type PersonStoreMock struct",
"CreateFunc func(ctx context.Context, person *Person, confirm bool) error",
"GetFunc func(ctx context.Context, id string) (*Person, error)",
"func (mock *PersonStoreMock) Create(ctx context.Context, person *Person, confirm bool) error",
"func (mock *PersonStoreMock) Get(ctx context.Context, id string) (*Person, error)",
}
for _, str := range strs {
if !strings.Contains(s, str) {
t.Errorf("expected but missing: \"%s\"", str)
}
}
}
func TestMoqExplicitPackage(t *testing.T) {
m, err := New("../../example", "different")
if err != nil {
t.Errorf("moq.New: %s", err)
}
var buf bytes.Buffer
err = m.Mock(&buf, "PersonStore")
if err != nil {
t.Errorf("m.Mock: %s", err)
}
s := buf.String()
log.Println(s)
// assertions of things that should be mentioned
var strs = []string{
"package different",
"type PersonStoreMock struct",
"CreateFunc func(ctx context.Context, person *example.Person, confirm bool) error",
"GetFunc func(ctx context.Context, id string) (*example.Person, error)",
"func (mock *PersonStoreMock) Create(ctx context.Context, person *example.Person, confirm bool) error",
"func (mock *PersonStoreMock) Get(ctx context.Context, id string) (*example.Person, error)",
}
for _, str := range strs {
if !strings.Contains(s, str) {
t.Errorf("expected but missing: \"%s\"", str)
}
}
} }