Internal registry for disambiguated imports, vars (#141)
* Internal registry for disambiguated imports, vars - Move functionality in the moq package partially into internal/{registry,template}. - Leverage registry to assign unique package and variable/method parameter names. Use import aliases if present in interface source package. BREAKING CHANGE: When the interface definition does not mention the parameter names, the field names in call info anonymous struct will be different. The new field names are generated using the type info (string -> s, int -> n, chan int -> intCh, []MyType -> myTypes, map[string]int -> stringToInt etc.). For example, for a string parameter previously if the field name was 'In1', the new field could be 'S' or 'S1' (depends on number of string method parameters). * Refactor golden file tests to be table-driven * Fix sync pkg alias handling for moq generation * Improve, add tests (increase coverage) * Use $.Foo in template, avoid declaring variables $ is set to the data argument passed to Execute, that is, to the starting value of dot. Variables were declared to be able to refer to the parent context. * Consistent template field formatting * Use tabs in generated Godoc comments' example code * Minor simplification * go generate * Fix conflict for generated param name of pointer type Excellent work by @sudo-suhas.
This commit is contained in:
155
internal/registry/method_scope.go
Normal file
155
internal/registry/method_scope.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// MethodScope is the sub-registry for allocating variables present in
|
||||
// the method scope.
|
||||
//
|
||||
// It should be created using a registry instance.
|
||||
type MethodScope struct {
|
||||
registry *Registry
|
||||
moqPkgPath string
|
||||
|
||||
vars []*Var
|
||||
conflicted map[string]bool
|
||||
}
|
||||
|
||||
// AddVar allocates a variable instance and adds it to the method scope.
|
||||
//
|
||||
// Variables names are generated if required and are ensured to be
|
||||
// without conflict with other variables and imported packages. It also
|
||||
// adds the relevant imports to the registry for each added variable.
|
||||
func (m *MethodScope) AddVar(vr *types.Var, suffix string) *Var {
|
||||
name := vr.Name()
|
||||
if name == "" || name == "_" {
|
||||
name = generateVarName(vr.Type())
|
||||
}
|
||||
|
||||
name += suffix
|
||||
|
||||
switch name {
|
||||
case "mock", "callInfo", "break", "default", "func", "interface", "select", "case", "defer", "go", "map", "struct",
|
||||
"chan", "else", "goto", "package", "switch", "const", "fallthrough", "if", "range", "type", "continue", "for",
|
||||
"import", "return", "var":
|
||||
name += "MoqParam"
|
||||
}
|
||||
|
||||
if _, ok := m.searchVar(name); ok || m.conflicted[name] {
|
||||
return m.addDisambiguatedVar(vr, name)
|
||||
}
|
||||
|
||||
return m.addVar(vr, name)
|
||||
}
|
||||
|
||||
func (m *MethodScope) addDisambiguatedVar(vr *types.Var, suggested string) *Var {
|
||||
n := 1
|
||||
for {
|
||||
// Keep incrementing the suffix until we find a name which is unused.
|
||||
if _, ok := m.searchVar(suggested + strconv.Itoa(n)); !ok {
|
||||
break
|
||||
}
|
||||
n++
|
||||
}
|
||||
|
||||
name := suggested + strconv.Itoa(n)
|
||||
if n == 1 {
|
||||
conflict, _ := m.searchVar(suggested)
|
||||
conflict.Name += "1"
|
||||
name = suggested + "2"
|
||||
m.conflicted[suggested] = true
|
||||
}
|
||||
|
||||
return m.addVar(vr, name)
|
||||
}
|
||||
|
||||
func (m *MethodScope) addVar(vr *types.Var, name string) *Var {
|
||||
imports := make(map[string]*Package)
|
||||
m.populateImports(vr.Type(), imports)
|
||||
|
||||
v := Var{
|
||||
vr: vr,
|
||||
imports: imports,
|
||||
moqPkgPath: m.moqPkgPath,
|
||||
Name: name,
|
||||
}
|
||||
m.vars = append(m.vars, &v)
|
||||
m.resolveImportVarConflicts(&v)
|
||||
return &v
|
||||
}
|
||||
|
||||
func (m MethodScope) searchVar(name string) (*Var, bool) {
|
||||
for _, v := range m.vars {
|
||||
if v.Name == name {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// populateImports extracts all the package imports for a given type
|
||||
// recursively. The imported packages by a single type can be more than
|
||||
// one (ex: map[a.Type]b.Type).
|
||||
func (m MethodScope) populateImports(t types.Type, imports map[string]*Package) {
|
||||
switch t := t.(type) {
|
||||
case *types.Named:
|
||||
if pkg := t.Obj().Pkg(); pkg != nil {
|
||||
imports[stripVendorPath(pkg.Path())] = m.registry.AddImport(pkg)
|
||||
}
|
||||
|
||||
case *types.Array:
|
||||
m.populateImports(t.Elem(), imports)
|
||||
|
||||
case *types.Slice:
|
||||
m.populateImports(t.Elem(), imports)
|
||||
|
||||
case *types.Signature:
|
||||
for i := 0; i < t.Params().Len(); i++ {
|
||||
m.populateImports(t.Params().At(i).Type(), imports)
|
||||
}
|
||||
for i := 0; i < t.Results().Len(); i++ {
|
||||
m.populateImports(t.Results().At(i).Type(), imports)
|
||||
}
|
||||
|
||||
case *types.Map:
|
||||
m.populateImports(t.Key(), imports)
|
||||
m.populateImports(t.Elem(), imports)
|
||||
|
||||
case *types.Chan:
|
||||
m.populateImports(t.Elem(), imports)
|
||||
|
||||
case *types.Pointer:
|
||||
m.populateImports(t.Elem(), imports)
|
||||
|
||||
case *types.Struct: // anonymous struct
|
||||
for i := 0; i < t.NumFields(); i++ {
|
||||
m.populateImports(t.Field(i).Type(), imports)
|
||||
}
|
||||
|
||||
case *types.Interface: // anonymous interface
|
||||
for i := 0; i < t.NumExplicitMethods(); i++ {
|
||||
m.populateImports(t.ExplicitMethod(i).Type(), imports)
|
||||
}
|
||||
for i := 0; i < t.NumEmbeddeds(); i++ {
|
||||
m.populateImports(t.EmbeddedType(i), imports)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m MethodScope) resolveImportVarConflicts(v *Var) {
|
||||
// Ensure that the newly added var does not conflict with a package import
|
||||
// which was added earlier.
|
||||
if _, ok := m.registry.searchImport(v.Name); ok {
|
||||
v.Name += "MoqParam"
|
||||
}
|
||||
// Ensure that all the newly added imports do not conflict with any of the
|
||||
// existing vars.
|
||||
for _, imprt := range v.imports {
|
||||
if v, ok := m.searchVar(imprt.Qualifier()); ok {
|
||||
v.Name += "MoqParam"
|
||||
}
|
||||
}
|
||||
}
|
93
internal/registry/package.go
Normal file
93
internal/registry/package.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Package represents an imported package.
|
||||
type Package struct {
|
||||
pkg *types.Package
|
||||
|
||||
Alias string
|
||||
}
|
||||
|
||||
// NewPackage creates a new instance of Package.
|
||||
func NewPackage(pkg *types.Package) *Package { return &Package{pkg: pkg} }
|
||||
|
||||
// Qualifier returns the qualifier which must be used to refer to types
|
||||
// declared in the package.
|
||||
func (p *Package) Qualifier() string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if p.Alias != "" {
|
||||
return p.Alias
|
||||
}
|
||||
|
||||
return p.pkg.Name()
|
||||
}
|
||||
|
||||
// Path is the full package import path (without vendor).
|
||||
func (p *Package) Path() string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return stripVendorPath(p.pkg.Path())
|
||||
}
|
||||
|
||||
var replacer = strings.NewReplacer(
|
||||
"go-", "",
|
||||
"-go", "",
|
||||
"-", "",
|
||||
"_", "",
|
||||
".", "",
|
||||
"@", "",
|
||||
"+", "",
|
||||
"~", "",
|
||||
)
|
||||
|
||||
// uniqueName generates a unique name for a package by concatenating
|
||||
// path components. The generated name is guaranteed to unique with an
|
||||
// appropriate level because the full package import paths themselves
|
||||
// are unique.
|
||||
func (p Package) uniqueName(lvl int) string {
|
||||
pp := strings.Split(p.Path(), "/")
|
||||
reverse(pp)
|
||||
|
||||
var name string
|
||||
for i := 0; i < min(len(pp), lvl+1); i++ {
|
||||
name = strings.ToLower(replacer.Replace(pp[i])) + name
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
||||
// stripVendorPath strips the vendor dir prefix from a package path.
|
||||
// For example we might encounter an absolute path like
|
||||
// github.com/foo/bar/vendor/github.com/pkg/errors which is resolved
|
||||
// to github.com/pkg/errors.
|
||||
func stripVendorPath(p string) string {
|
||||
parts := strings.Split(p, "/vendor/")
|
||||
if len(parts) == 1 {
|
||||
return p
|
||||
}
|
||||
return strings.TrimLeft(path.Join(parts[1:]...), "/")
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func reverse(a []string) {
|
||||
for i := len(a)/2 - 1; i >= 0; i-- {
|
||||
opp := len(a) - 1 - i
|
||||
a[i], a[opp] = a[opp], a[i]
|
||||
}
|
||||
}
|
190
internal/registry/registry.go
Normal file
190
internal/registry/registry.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/types"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/tools/go/packages"
|
||||
)
|
||||
|
||||
// Registry encapsulates types information for the source and mock
|
||||
// destination package. For the mock package, it tracks the list of
|
||||
// imports and ensures there are no conflicts in the imported package
|
||||
// qualifiers.
|
||||
type Registry struct {
|
||||
srcPkg *packages.Package
|
||||
moqPkgPath string
|
||||
aliases map[string]string
|
||||
imports map[string]*Package
|
||||
}
|
||||
|
||||
// New loads the source package info and returns a new instance of
|
||||
// Registry.
|
||||
func New(srcDir, moqPkg string) (*Registry, error) {
|
||||
srcPkg, err := pkgInfoFromPath(
|
||||
srcDir, packages.NeedName|packages.NeedSyntax|packages.NeedTypes|packages.NeedTypesInfo,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't load source package: %s", err)
|
||||
}
|
||||
|
||||
return &Registry{
|
||||
srcPkg: srcPkg,
|
||||
moqPkgPath: findPkgPath(moqPkg, srcPkg),
|
||||
aliases: parseImportsAliases(srcPkg),
|
||||
imports: make(map[string]*Package),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SrcPkg returns the types info for the source package.
|
||||
func (r Registry) SrcPkg() *types.Package {
|
||||
return r.srcPkg.Types
|
||||
}
|
||||
|
||||
// SrcPkgName returns the name of the source package.
|
||||
func (r Registry) SrcPkgName() string {
|
||||
return r.srcPkg.Name
|
||||
}
|
||||
|
||||
// LookupInterface returns the underlying interface definition of the
|
||||
// given interface name.
|
||||
func (r Registry) LookupInterface(name string) (*types.Interface, error) {
|
||||
obj := r.SrcPkg().Scope().Lookup(name)
|
||||
if obj == nil {
|
||||
return nil, fmt.Errorf("interface not found: %s", name)
|
||||
}
|
||||
|
||||
if !types.IsInterface(obj.Type()) {
|
||||
return nil, fmt.Errorf("%s (%s) is not an interface", name, obj.Type())
|
||||
}
|
||||
|
||||
return obj.Type().Underlying().(*types.Interface).Complete(), nil
|
||||
}
|
||||
|
||||
// MethodScope returns a new MethodScope.
|
||||
func (r *Registry) MethodScope() *MethodScope {
|
||||
return &MethodScope{
|
||||
registry: r,
|
||||
moqPkgPath: r.moqPkgPath,
|
||||
conflicted: map[string]bool{},
|
||||
}
|
||||
}
|
||||
|
||||
// AddImport adds the given package to the set of imports. It generates a
|
||||
// suitable alias if there are any conflicts with previously imported
|
||||
// packages.
|
||||
func (r *Registry) AddImport(pkg *types.Package) *Package {
|
||||
path := stripVendorPath(pkg.Path())
|
||||
if path == r.moqPkgPath {
|
||||
return nil
|
||||
}
|
||||
|
||||
if imprt, ok := r.imports[path]; ok {
|
||||
return imprt
|
||||
}
|
||||
|
||||
imprt := Package{pkg: pkg, Alias: r.aliases[path]}
|
||||
|
||||
if conflict, ok := r.searchImport(imprt.Qualifier()); ok {
|
||||
resolveImportConflict(&imprt, conflict, 0)
|
||||
}
|
||||
|
||||
r.imports[path] = &imprt
|
||||
return &imprt
|
||||
}
|
||||
|
||||
// Imports returns the list of imported packages. The list is sorted by
|
||||
// path.
|
||||
func (r Registry) Imports() []*Package {
|
||||
imports := make([]*Package, 0, len(r.imports))
|
||||
for _, imprt := range r.imports {
|
||||
imports = append(imports, imprt)
|
||||
}
|
||||
sort.Slice(imports, func(i, j int) bool {
|
||||
return imports[i].Path() < imports[j].Path()
|
||||
})
|
||||
return imports
|
||||
}
|
||||
|
||||
func (r Registry) searchImport(name string) (*Package, bool) {
|
||||
for _, imprt := range r.imports {
|
||||
if imprt.Qualifier() == name {
|
||||
return imprt, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, error) {
|
||||
pkgs, err := packages.Load(&packages.Config{
|
||||
Mode: mode,
|
||||
Dir: srcDir,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(pkgs) == 0 {
|
||||
return nil, errors.New("package not found")
|
||||
}
|
||||
if len(pkgs) > 1 {
|
||||
return nil, errors.New("found more than one package")
|
||||
}
|
||||
if errs := pkgs[0].Errors; len(errs) != 0 {
|
||||
if len(errs) == 1 {
|
||||
return nil, errs[0]
|
||||
}
|
||||
return nil, fmt.Errorf("%s (and %d more errors)", errs[0], len(errs)-1)
|
||||
}
|
||||
return pkgs[0], nil
|
||||
}
|
||||
|
||||
func findPkgPath(pkgInputVal string, srcPkg *packages.Package) string {
|
||||
if pkgInputVal == "" {
|
||||
return srcPkg.PkgPath
|
||||
}
|
||||
if pkgInDir(srcPkg.PkgPath, pkgInputVal) {
|
||||
return srcPkg.PkgPath
|
||||
}
|
||||
subdirectoryPath := filepath.Join(srcPkg.PkgPath, pkgInputVal)
|
||||
if pkgInDir(subdirectoryPath, pkgInputVal) {
|
||||
return subdirectoryPath
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func pkgInDir(pkgName, dir string) bool {
|
||||
currentPkg, err := pkgInfoFromPath(dir, packages.NeedName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return currentPkg.Name == pkgName || currentPkg.Name+"_test" == pkgName
|
||||
}
|
||||
|
||||
func parseImportsAliases(pkg *packages.Package) map[string]string {
|
||||
aliases := make(map[string]string)
|
||||
for _, syntax := range pkg.Syntax {
|
||||
for _, imprt := range syntax.Imports {
|
||||
if imprt.Name != nil && imprt.Name.Name != "." {
|
||||
aliases[strings.Trim(imprt.Path.Value, `"`)] = imprt.Name.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
return aliases
|
||||
}
|
||||
|
||||
// resolveImportConflict generates and assigns a unique alias for
|
||||
// packages with conflicting qualifiers.
|
||||
func resolveImportConflict(a, b *Package, lvl int) {
|
||||
u1, u2 := a.uniqueName(lvl), b.uniqueName(lvl)
|
||||
if u1 != u2 {
|
||||
a.Alias, b.Alias = u1, u2
|
||||
return
|
||||
}
|
||||
|
||||
resolveImportConflict(a, b, lvl+1)
|
||||
}
|
123
internal/registry/var.go
Normal file
123
internal/registry/var.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Var represents a method variable/parameter.
|
||||
//
|
||||
// It should be created using a method scope instance.
|
||||
type Var struct {
|
||||
vr *types.Var
|
||||
imports map[string]*Package
|
||||
moqPkgPath string
|
||||
|
||||
Name string
|
||||
}
|
||||
|
||||
// IsSlice returns whether the type (or the underlying type) is a slice.
|
||||
func (v Var) IsSlice() bool {
|
||||
_, ok := v.vr.Type().Underlying().(*types.Slice)
|
||||
return ok
|
||||
}
|
||||
|
||||
// TypeString returns the variable type with the package qualifier in the
|
||||
// format 'pkg.Type'.
|
||||
func (v Var) TypeString() string {
|
||||
return types.TypeString(v.vr.Type(), v.packageQualifier)
|
||||
}
|
||||
|
||||
// packageQualifier is a types.Qualifier.
|
||||
func (v Var) packageQualifier(pkg *types.Package) string {
|
||||
path := stripVendorPath(pkg.Path())
|
||||
if v.moqPkgPath != "" && v.moqPkgPath == path {
|
||||
return ""
|
||||
}
|
||||
|
||||
return v.imports[path].Qualifier()
|
||||
}
|
||||
|
||||
// generateVarName generates a name for the variable using the type
|
||||
// information.
|
||||
//
|
||||
// Examples:
|
||||
// - string -> s
|
||||
// - int -> n
|
||||
// - chan int -> intCh
|
||||
// - []a.MyType -> myTypes
|
||||
// - map[string]int -> stringToInt
|
||||
// - error -> err
|
||||
// - a.MyType -> myType
|
||||
func generateVarName(t types.Type) string {
|
||||
nestedType := func(t types.Type) string {
|
||||
if t, ok := t.(*types.Basic); ok {
|
||||
return deCapitalise(t.String())
|
||||
}
|
||||
return generateVarName(t)
|
||||
}
|
||||
|
||||
switch t := t.(type) {
|
||||
case *types.Named:
|
||||
if t.Obj().Name() == "error" {
|
||||
return "err"
|
||||
}
|
||||
|
||||
name := deCapitalise(t.Obj().Name())
|
||||
if name == t.Obj().Name() {
|
||||
name += "MoqParam"
|
||||
}
|
||||
|
||||
return name
|
||||
|
||||
case *types.Basic:
|
||||
return basicTypeVarName(t)
|
||||
|
||||
case *types.Array:
|
||||
return nestedType(t.Elem()) + "s"
|
||||
|
||||
case *types.Slice:
|
||||
return nestedType(t.Elem()) + "s"
|
||||
|
||||
case *types.Struct: // anonymous struct
|
||||
return "val"
|
||||
|
||||
case *types.Pointer:
|
||||
return generateVarName(t.Elem())
|
||||
|
||||
case *types.Signature:
|
||||
return "fn"
|
||||
|
||||
case *types.Interface: // anonymous interface
|
||||
return "ifaceVal"
|
||||
|
||||
case *types.Map:
|
||||
return nestedType(t.Key()) + "To" + capitalise(nestedType(t.Elem()))
|
||||
|
||||
case *types.Chan:
|
||||
return nestedType(t.Elem()) + "Ch"
|
||||
}
|
||||
|
||||
return "v"
|
||||
}
|
||||
|
||||
func basicTypeVarName(b *types.Basic) string {
|
||||
switch b.Info() {
|
||||
case types.IsBoolean:
|
||||
return "b"
|
||||
|
||||
case types.IsInteger:
|
||||
return "n"
|
||||
|
||||
case types.IsFloat:
|
||||
return "f"
|
||||
|
||||
case types.IsString:
|
||||
return "s"
|
||||
}
|
||||
|
||||
return "v"
|
||||
}
|
||||
|
||||
func capitalise(s string) string { return strings.ToUpper(s[:1]) + s[1:] }
|
||||
func deCapitalise(s string) string { return strings.ToLower(s[:1]) + s[1:] }
|
190
internal/template/template.go
Normal file
190
internal/template/template.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package template
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/matryer/moq/internal/registry"
|
||||
)
|
||||
|
||||
// Template is the Moq template. It is capable of generating the Moq
|
||||
// implementation for the given template.Data.
|
||||
type Template struct {
|
||||
tmpl *template.Template
|
||||
}
|
||||
|
||||
// New returns a new instance of Template.
|
||||
func New() (Template, error) {
|
||||
tmpl, err := template.New("moq").Funcs(templateFuncs).Parse(moqTemplate)
|
||||
if err != nil {
|
||||
return Template{}, err
|
||||
}
|
||||
|
||||
return Template{tmpl: tmpl}, nil
|
||||
}
|
||||
|
||||
// Execute generates and writes the Moq implementation for the given
|
||||
// data.
|
||||
func (t Template) Execute(w io.Writer, data Data) error {
|
||||
return t.tmpl.Execute(w, data)
|
||||
}
|
||||
|
||||
// moqTemplate is the template for mocked code.
|
||||
// language=GoTemplate
|
||||
var moqTemplate = `// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package {{.PkgName}}
|
||||
|
||||
import (
|
||||
{{- range .Imports}}
|
||||
{{. | ImportStatement}}
|
||||
{{- end}}
|
||||
)
|
||||
|
||||
{{range $i, $mock := .Mocks -}}
|
||||
|
||||
{{- if not $.SkipEnsure -}}
|
||||
// Ensure, that {{.MockName}} does implement {{$.SrcPkgQualifier}}{{.InterfaceName}}.
|
||||
// If this is not the case, regenerate this file with moq.
|
||||
var _ {{$.SrcPkgQualifier}}{{.InterfaceName}} = &{{.MockName}}{}
|
||||
{{- end}}
|
||||
|
||||
// {{.MockName}} is a mock implementation of {{$.SrcPkgQualifier}}{{.InterfaceName}}.
|
||||
//
|
||||
// func TestSomethingThatUses{{.InterfaceName}}(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked {{$.SrcPkgQualifier}}{{.InterfaceName}}
|
||||
// mocked{{.InterfaceName}} := &{{.MockName}}{
|
||||
{{- range .Methods}}
|
||||
// {{.Name}}Func: func({{.ArgList}}) {{.ReturnArgTypeList}} {
|
||||
// panic("mock out the {{.Name}} method")
|
||||
// },
|
||||
{{- end}}
|
||||
// }
|
||||
//
|
||||
// // use mocked{{.InterfaceName}} in code that requires {{$.SrcPkgQualifier}}{{.InterfaceName}}
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type {{.MockName}} struct {
|
||||
{{- range .Methods}}
|
||||
// {{.Name}}Func mocks the {{.Name}} method.
|
||||
{{.Name}}Func func({{.ArgList}}) {{.ReturnArgTypeList}}
|
||||
{{end}}
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
{{- range .Methods}}
|
||||
// {{.Name}} holds details about calls to the {{.Name}} method.
|
||||
{{.Name}} []struct {
|
||||
{{- range .Params}}
|
||||
// {{.Name | Exported}} is the {{.Name}} argument value.
|
||||
{{.Name | Exported}} {{.TypeString}}
|
||||
{{- end}}
|
||||
}
|
||||
{{- end}}
|
||||
}
|
||||
{{- range .Methods}}
|
||||
lock{{.Name}} {{$.Imports | SyncPkgQualifier}}.RWMutex
|
||||
{{- end}}
|
||||
}
|
||||
{{range .Methods}}
|
||||
// {{.Name}} calls {{.Name}}Func.
|
||||
func (mock *{{$mock.MockName}}) {{.Name}}({{.ArgList}}) {{.ReturnArgTypeList}} {
|
||||
{{- if not $.StubImpl}}
|
||||
if mock.{{.Name}}Func == nil {
|
||||
panic("{{$mock.MockName}}.{{.Name}}Func: method is nil but {{$mock.InterfaceName}}.{{.Name}} was just called")
|
||||
}
|
||||
{{- end}}
|
||||
callInfo := struct {
|
||||
{{- range .Params}}
|
||||
{{.Name | Exported}} {{.TypeString}}
|
||||
{{- end}}
|
||||
}{
|
||||
{{- range .Params}}
|
||||
{{.Name | Exported}}: {{.Name}},
|
||||
{{- end}}
|
||||
}
|
||||
mock.lock{{.Name}}.Lock()
|
||||
mock.calls.{{.Name}} = append(mock.calls.{{.Name}}, callInfo)
|
||||
mock.lock{{.Name}}.Unlock()
|
||||
{{- if .Returns}}
|
||||
{{- if $.StubImpl}}
|
||||
if mock.{{.Name}}Func == nil {
|
||||
var (
|
||||
{{- range .Returns}}
|
||||
{{.Name}} {{.TypeString}}
|
||||
{{- end}}
|
||||
)
|
||||
return {{.ReturnArgNameList}}
|
||||
}
|
||||
{{- end}}
|
||||
return mock.{{.Name}}Func({{.ArgCallList}})
|
||||
{{- else}}
|
||||
{{- if $.StubImpl}}
|
||||
if mock.{{.Name}}Func == nil {
|
||||
return
|
||||
}
|
||||
{{- end}}
|
||||
mock.{{.Name}}Func({{.ArgCallList}})
|
||||
{{- end}}
|
||||
}
|
||||
|
||||
// {{.Name}}Calls gets all the calls that were made to {{.Name}}.
|
||||
// Check the length with:
|
||||
// len(mocked{{$mock.InterfaceName}}.{{.Name}}Calls())
|
||||
func (mock *{{$mock.MockName}}) {{.Name}}Calls() []struct {
|
||||
{{- range .Params}}
|
||||
{{.Name | Exported}} {{.TypeString}}
|
||||
{{- end}}
|
||||
} {
|
||||
var calls []struct {
|
||||
{{- range .Params}}
|
||||
{{.Name | Exported}} {{.TypeString}}
|
||||
{{- end}}
|
||||
}
|
||||
mock.lock{{.Name}}.RLock()
|
||||
calls = mock.calls.{{.Name}}
|
||||
mock.lock{{.Name}}.RUnlock()
|
||||
return calls
|
||||
}
|
||||
{{end -}}
|
||||
{{end -}}`
|
||||
|
||||
// This list comes from the golint codebase. Golint will complain about any of
|
||||
// these being mixed-case, like "Id" instead of "ID".
|
||||
var golintInitialisms = []string{
|
||||
"ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS",
|
||||
"QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "UID", "UUID", "URI",
|
||||
"URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS",
|
||||
}
|
||||
|
||||
var templateFuncs = template.FuncMap{
|
||||
"ImportStatement": func(imprt *registry.Package) string {
|
||||
if imprt.Alias == "" {
|
||||
return `"` + imprt.Path() + `"`
|
||||
}
|
||||
return imprt.Alias + ` "` + imprt.Path() + `"`
|
||||
},
|
||||
"SyncPkgQualifier": func(imports []*registry.Package) string {
|
||||
for _, imprt := range imports {
|
||||
if imprt.Path() == "sync" {
|
||||
return imprt.Qualifier()
|
||||
}
|
||||
}
|
||||
|
||||
return "sync"
|
||||
},
|
||||
"Exported": func(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
for _, initialism := range golintInitialisms {
|
||||
if strings.ToUpper(s) == initialism {
|
||||
return initialism
|
||||
}
|
||||
}
|
||||
return strings.ToUpper(s[0:1]) + s[1:]
|
||||
},
|
||||
}
|
125
internal/template/template_data.go
Normal file
125
internal/template/template_data.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package template
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matryer/moq/internal/registry"
|
||||
)
|
||||
|
||||
// Data is the template data used to render the Moq template.
|
||||
type Data struct {
|
||||
PkgName string
|
||||
SrcPkgQualifier string
|
||||
Imports []*registry.Package
|
||||
Mocks []MockData
|
||||
StubImpl bool
|
||||
SkipEnsure bool
|
||||
}
|
||||
|
||||
// MocksSomeMethod returns true of any one of the Mocks has at least 1
|
||||
// method.
|
||||
func (d Data) MocksSomeMethod() bool {
|
||||
for _, m := range d.Mocks {
|
||||
if len(m.Methods) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// MockData is the data used to generate a mock for some interface.
|
||||
type MockData struct {
|
||||
InterfaceName string
|
||||
MockName string
|
||||
Methods []MethodData
|
||||
}
|
||||
|
||||
// MethodData is the data which represents a method on some interface.
|
||||
type MethodData struct {
|
||||
Name string
|
||||
Params []ParamData
|
||||
Returns []ParamData
|
||||
}
|
||||
|
||||
// ArgList is the string representation of method parameters, ex:
|
||||
// 's string, n int, foo bar.Baz'.
|
||||
func (m MethodData) ArgList() string {
|
||||
params := make([]string, len(m.Params))
|
||||
for i, p := range m.Params {
|
||||
params[i] = p.MethodArg()
|
||||
}
|
||||
return strings.Join(params, ", ")
|
||||
}
|
||||
|
||||
// ArgCallList is the string representation of method call parameters,
|
||||
// ex: 's, n, foo'. In case of a last variadic parameter, it will be of
|
||||
// the format 's, n, foos...'
|
||||
func (m MethodData) ArgCallList() string {
|
||||
params := make([]string, len(m.Params))
|
||||
for i, p := range m.Params {
|
||||
params[i] = p.CallName()
|
||||
}
|
||||
return strings.Join(params, ", ")
|
||||
}
|
||||
|
||||
// ReturnArgTypeList is the string representation of method return
|
||||
// types, ex: 'bar.Baz', '(string, error)'.
|
||||
func (m MethodData) ReturnArgTypeList() string {
|
||||
params := make([]string, len(m.Returns))
|
||||
for i, p := range m.Returns {
|
||||
params[i] = p.TypeString()
|
||||
}
|
||||
if len(m.Returns) > 1 {
|
||||
return fmt.Sprintf("(%s)", strings.Join(params, ", "))
|
||||
}
|
||||
return strings.Join(params, ", ")
|
||||
}
|
||||
|
||||
// ReturnArgNameList is the string representation of values being
|
||||
// returned from the method, ex: 'foo', 's, err'.
|
||||
func (m MethodData) ReturnArgNameList() string {
|
||||
params := make([]string, len(m.Returns))
|
||||
for i, p := range m.Returns {
|
||||
params[i] = p.Name()
|
||||
}
|
||||
return strings.Join(params, ", ")
|
||||
}
|
||||
|
||||
// ParamData is the data which represents a parameter to some method of
|
||||
// an interface.
|
||||
type ParamData struct {
|
||||
Var *registry.Var
|
||||
Variadic bool
|
||||
}
|
||||
|
||||
// Name returns the name of the parameter.
|
||||
func (p ParamData) Name() string {
|
||||
return p.Var.Name
|
||||
}
|
||||
|
||||
// MethodArg is the representation of the parameter in the function
|
||||
// signature, ex: 'name a.Type'.
|
||||
func (p ParamData) MethodArg() string {
|
||||
if p.Variadic {
|
||||
return fmt.Sprintf("%s ...%s", p.Name(), p.TypeString()[2:])
|
||||
}
|
||||
return fmt.Sprintf("%s %s", p.Name(), p.TypeString())
|
||||
}
|
||||
|
||||
// CallName returns the string representation of the parameter to be
|
||||
// used for a method call. For a variadic paramter, it will be of the
|
||||
// format 'foos...'.
|
||||
func (p ParamData) CallName() string {
|
||||
if p.Variadic {
|
||||
return p.Name() + "..."
|
||||
}
|
||||
return p.Name()
|
||||
}
|
||||
|
||||
// TypeString returns the string representation of the type of the
|
||||
// parameter.
|
||||
func (p ParamData) TypeString() string {
|
||||
return p.Var.TypeString()
|
||||
}
|
55
internal/template/template_test.go
Normal file
55
internal/template/template_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package template
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
"testing"
|
||||
|
||||
"github.com/matryer/moq/internal/registry"
|
||||
)
|
||||
|
||||
func TestTemplateFuncs(t *testing.T) {
|
||||
t.Run("Exported", func(t *testing.T) {
|
||||
f := templateFuncs["Exported"].(func(string) string)
|
||||
if f("") != "" {
|
||||
t.Errorf("Exported(...) want: ``; got: `%s`", f(""))
|
||||
}
|
||||
if f("var") != "Var" {
|
||||
t.Errorf("Exported(...) want: `Var`; got: `%s`", f("var"))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImportStatement", func(t *testing.T) {
|
||||
f := templateFuncs["ImportStatement"].(func(*registry.Package) string)
|
||||
pkg := registry.NewPackage(types.NewPackage("xyz", "xyz"))
|
||||
if f(pkg) != `"xyz"` {
|
||||
t.Errorf("ImportStatement(...): want: `\"xyz\"`; got: `%s`", f(pkg))
|
||||
}
|
||||
|
||||
pkg.Alias = "x"
|
||||
if f(pkg) != `x "xyz"` {
|
||||
t.Errorf("ImportStatement(...): want: `x \"xyz\"`; got: `%s`", f(pkg))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SyncPkgQualifier", func(t *testing.T) {
|
||||
f := templateFuncs["SyncPkgQualifier"].(func([]*registry.Package) string)
|
||||
if f(nil) != "sync" {
|
||||
t.Errorf("SyncPkgQualifier(...): want: `sync`; got: `%s`", f(nil))
|
||||
}
|
||||
imports := []*registry.Package{
|
||||
registry.NewPackage(types.NewPackage("sync", "sync")),
|
||||
registry.NewPackage(types.NewPackage("github.com/some/module", "module")),
|
||||
}
|
||||
if f(imports) != "sync" {
|
||||
t.Errorf("SyncPkgQualifier(...): want: `sync`; got: `%s`", f(imports))
|
||||
}
|
||||
|
||||
syncPkg := registry.NewPackage(types.NewPackage("sync", "sync"))
|
||||
syncPkg.Alias = "stdsync"
|
||||
otherSyncPkg := registry.NewPackage(types.NewPackage("github.com/someother/sync", "sync"))
|
||||
imports = []*registry.Package{otherSyncPkg, syncPkg}
|
||||
if f(imports) != "stdsync" {
|
||||
t.Errorf("SyncPkgQualifier(...): want: `stdsync`; got: `%s`", f(imports))
|
||||
}
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user