moq/pkg/moq/moq.go
Suhas Karanth 83ab8dd79f
Add -stub flag to stub func implementation (#132)
When a mock is generated with the flag enabled, it introduces the
following changes:
- Does not panic on calling the method without a mock implementation.
- Returns zero values iff the implementation is not provided and the
  method has return parameters.

Co-authored-by: Scott Leuthaeuser <scott_leuthaeuser@homedepot.com>
2020-08-16 12:24:12 +05:30

364 lines
8.0 KiB
Go

package moq
import (
"bytes"
"errors"
"fmt"
"go/build"
"go/types"
"io"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"text/template"
"golang.org/x/tools/go/packages"
)
// Mocker can generate mock structs.
// It is created by New and holds everything Mock needs to render and
// format the generated source.
type Mocker struct {
// srcPkg is the loaded package that declares the interfaces to mock.
srcPkg *packages.Package
// tmpl is the parsed mock template executed by Mock.
tmpl *template.Template
// pkgName is the package name written into the generated file.
pkgName string
// pkgPath is the package path resolved by findPkgPath; types from this
// path are rendered without a package qualifier.
pkgPath string
// fmter formats the rendered template output before it is written.
fmter func(src []byte) ([]byte, error)
// stubImpl is copied from Config.StubImpl and handed to the template.
stubImpl bool
// imports is the set of package paths recorded by packageQualifier while
// rendering parameter and result types.
imports map[string]bool
}
// Config specifies details about how interfaces should be mocked.
// SrcDir is the only field which needs to be specified.
type Config struct {
// SrcDir is the directory the source package is loaded from.
SrcDir string
// PkgName is the package name for the generated mock; when empty, New
// falls back to the name of the source package.
PkgName string
// Formatter selects the output formatter: "goimports" picks goimports,
// any other value picks gofmt.
Formatter string
// StubImpl is passed through to the template as the stub flag.
StubImpl bool
}
// New makes a new Mocker for the specified package directory.
//
// It loads the package found in conf.SrcDir, resolves the package path of the
// mock, parses the mock template and picks the formatter. An error is returned
// when the source package cannot be loaded or the template does not parse.
func New(conf Config) (*Mocker, error) {
	srcPkg, err := pkgInfoFromPath(conf.SrcDir, packages.NeedName|packages.NeedTypes|packages.NeedTypesInfo)
	if err != nil {
		return nil, fmt.Errorf("couldn't load source package: %s", err)
	}

	pkgPath, err := findPkgPath(conf.PkgName, srcPkg)
	if err != nil {
		return nil, fmt.Errorf("couldn't load mock package: %s", err)
	}

	tmpl, err := template.New("moq").Funcs(templateFuncs).Parse(moqTemplate)
	if err != nil {
		return nil, err
	}

	mocker := &Mocker{
		srcPkg:   srcPkg,
		tmpl:     tmpl,
		pkgName:  conf.PkgName,
		pkgPath:  pkgPath,
		fmter:    gofmt,
		stubImpl: conf.StubImpl,
		imports:  make(map[string]bool),
	}
	// Without an explicit package name the mock lives in the source package.
	if mocker.pkgName == "" {
		mocker.pkgName = srcPkg.Name
	}
	// gofmt is the default; only the exact value "goimports" switches it.
	if conf.Formatter == "goimports" {
		mocker.fmter = goimports
	}
	return mocker, nil
}
// findPkgPath works out the package path the generated mock will live in.
//
// An empty pkgInputVal means the mock shares the source package, so the source
// package path is returned. Otherwise ".", the source package path and the
// source package path joined with pkgInputVal are probed in that order via
// pkgInDir, and the first match wins. "" is returned when nothing matches.
// The error result is always nil.
//
// NOTE(review): pkgInDir is declared as (pkgName, dir), yet every call here
// passes the candidate path in the pkgName position and pkgInputVal as dir.
// Confirm whether that argument order is intentional.
func findPkgPath(pkgInputVal string, srcPkg *packages.Package) (string, error) {
	if pkgInputVal == "" {
		return srcPkg.PkgPath, nil
	}

	candidates := []string{
		".",
		srcPkg.PkgPath,
		filepath.Join(srcPkg.PkgPath, pkgInputVal),
	}
	for _, candidate := range candidates {
		if pkgInDir(candidate, pkgInputVal) {
			return candidate, nil
		}
	}
	return "", nil
}
// pkgInDir reports whether the package loaded from dir is named pkgName, or
// whether pkgName is that package's name with a "_test" suffix. A directory
// that cannot be loaded counts as "no".
func pkgInDir(pkgName, dir string) bool {
	loaded, err := pkgInfoFromPath(dir, packages.NeedName)
	if err != nil {
		return false
	}
	switch pkgName {
	case loaded.Name, loaded.Name + "_test":
		return true
	}
	return false
}
// Mock generates a mock for the specified interface name.
//
// Each entry of names is parsed by parseInterfaceName, so it is either a bare
// interface name or "Interface:MockName". Every interface is looked up in the
// source package, its methods are collected, and the template is executed,
// formatted and written to w. An error is returned when no name is given,
// when a name is unknown or not an interface, or when rendering, formatting
// or writing fails.
func (m *Mocker) Mock(w io.Writer, names ...string) error {
	if len(names) == 0 {
		return errors.New("must specify one interface")
	}

	srcTypes := m.srcPkg.Types
	data := doc{
		PackageName: m.pkgName,
		Imports:     moqImports,
		StubImpl:    m.stubImpl,
	}

	needsSync := false
	for _, name := range names {
		ifaceName, mockName := parseInterfaceName(name)

		found := srcTypes.Scope().Lookup(ifaceName)
		if found == nil {
			return fmt.Errorf("cannot find interface %s", ifaceName)
		}
		if !types.IsInterface(found.Type()) {
			return fmt.Errorf("%s (%s) not an interface", ifaceName, found.Type().String())
		}

		ifaceType := found.Type().Underlying().(*types.Interface).Complete()
		numMethods := ifaceType.NumMethods()
		if numMethods > 0 {
			needsSync = true
		}

		mock := obj{InterfaceName: ifaceName, MockName: mockName}
		for i := 0; i < numMethods; i++ {
			fn := ifaceType.Method(i)
			// extractArgs also records every package the signature refers
			// to in m.imports, which is consumed below.
			params, returns := m.extractArgs(fn.Type().(*types.Signature))
			mock.Methods = append(mock.Methods, &method{
				Name:    fn.Name(),
				Params:  params,
				Returns: returns,
			})
		}
		data.Objects = append(data.Objects, mock)
	}

	// "sync" is only imported when at least one method is mocked.
	if needsSync {
		data.Imports = append(data.Imports, "sync")
	}
	for importPath := range m.imports {
		data.Imports = append(data.Imports, stripVendorPath(importPath))
	}
	// A mock in a differently named package must qualify and import the
	// source package.
	if srcName := srcTypes.Name(); srcName != m.pkgName {
		data.SourcePackagePrefix = srcName + "."
		data.Imports = append(data.Imports, stripVendorPath(srcTypes.Path()))
	}

	var rendered bytes.Buffer
	if err := m.tmpl.Execute(&rendered, data); err != nil {
		return err
	}
	formatted, err := m.fmter(rendered.Bytes())
	if err != nil {
		return err
	}
	_, err = w.Write(formatted)
	return err
}
// packageQualifier is the types.Qualifier used when rendering type strings.
// Types from the mock's own package path get no qualifier. Any other package
// is recorded in m.imports and qualified with its package name. A package
// path of "." is replaced by the working directory stripped of its Go source
// root, when the working directory can be determined.
func (m *Mocker) packageQualifier(pkg *types.Package) string {
	importPath := pkg.Path()
	if m.pkgPath != "" && importPath == m.pkgPath {
		return ""
	}
	if importPath == "." {
		if wd, err := os.Getwd(); err == nil {
			importPath = stripGopath(wd)
		}
	}
	m.imports[importPath] = true
	return pkg.Name()
}
// extractArgs converts the parameters and results of sig into param values.
// Unnamed (or "_") parameters are named in1, in2, ... and unnamed results
// out1, out2, ... by position.
func (m *Mocker) extractArgs(sig *types.Signature) (params, results []*param) {
	in := sig.Params()
	last := in.Len() - 1
	for i := 0; i <= last; i++ {
		arg := m.buildParam(in.At(i), "in"+strconv.Itoa(i+1))
		// Only the final parameter of a variadic signature can be variadic;
		// it is flagged when its rendered type is a slice.
		if i == last && sig.Variadic() {
			arg.Variadic = arg.Type[:2] == "[]"
		}
		params = append(params, arg)
	}

	out := sig.Results()
	for i, n := 0, out.Len(); i < n; i++ {
		results = append(results, m.buildParam(out.At(i), "out"+strconv.Itoa(i+1)))
	}
	return params, results
}
// buildParam describes v as a param. The type is rendered with
// m.packageQualifier, which records referenced packages as a side effect.
// fallbackName is used when v has no name or is named "_".
func (m *Mocker) buildParam(v *types.Var, fallbackName string) *param {
	p := &param{
		Name: v.Name(),
		Type: types.TypeString(v.Type(), m.packageQualifier),
	}
	switch p.Name {
	case "", "_":
		p.Name = fallbackName
	}
	return p
}
// pkgInfoFromPath loads the single package located in srcDir using the given
// load mode. It fails when loading fails, when zero or more than one package
// is found, or when the loaded package reports errors. In the last case the
// first error is returned, annotated with the count of the remaining ones.
func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, error) {
	cfg := &packages.Config{
		Mode: mode,
		Dir:  srcDir,
	}
	pkgs, err := packages.Load(cfg)
	if err != nil {
		return nil, err
	}

	switch {
	case len(pkgs) == 0:
		return nil, errors.New("No packages found")
	case len(pkgs) > 1:
		return nil, errors.New("More than one package was found")
	}

	pkg := pkgs[0]
	switch n := len(pkg.Errors); n {
	case 0:
		return pkg, nil
	case 1:
		return nil, pkg.Errors[0]
	default:
		return nil, fmt.Errorf("%s (and %d more errors)", pkg.Errors[0], n-1)
	}
}
// parseInterfaceName splits name at its first ":" into the interface name and
// the mock name. Without a ":" the mock name defaults to the interface name
// followed by "Mock".
func parseInterfaceName(name string) (ifaceName, mockName string) {
	sep := strings.Index(name, ":")
	if sep < 0 {
		return name, name + "Mock"
	}
	return name[:sep], name[sep+1:]
}
// doc is the data handed to the mock template by Mock.
type doc struct {
// PackageName is the package name of the generated file.
PackageName string
// SourcePackagePrefix is the source package name followed by "." when the
// mock lives in a differently named package, and empty otherwise.
SourcePackagePrefix string
// Objects holds one entry per mocked interface.
Objects []obj
// Imports lists the package paths the generated file imports.
Imports []string
// StubImpl carries the Mocker's stub setting into the template.
StubImpl bool
}
// obj describes a single interface to mock.
type obj struct {
// InterfaceName is the name of the interface in the source package.
InterfaceName string
// MockName is the name given to the generated mock.
MockName string
// Methods lists the interface's methods.
Methods []*method
}
// method describes one method of a mocked interface.
type method struct {
// Name is the method name.
Name string
// Params are the method's parameters, in declaration order.
Params []*param
// Returns are the method's results, in declaration order.
Returns []*param
}
// Arglist renders the parameters as a comma-separated "name type" list, the
// form used in a function signature.
func (m *method) Arglist() string {
	var decls []string
	for _, p := range m.Params {
		decls = append(decls, p.String())
	}
	return strings.Join(decls, ", ")
}
// ArgCallList renders the parameters as a comma-separated list of the
// expressions used to forward them in a call.
func (m *method) ArgCallList() string {
	var args []string
	for _, p := range m.Params {
		args = append(args, p.CallName())
	}
	return strings.Join(args, ", ")
}
// ReturnArgTypeList renders the result types as they appear after a function
// signature: empty for no results, the bare type for a single result, and a
// parenthesised comma-separated list for several.
func (m *method) ReturnArgTypeList() string {
	var typs []string
	for _, r := range m.Returns {
		typs = append(typs, r.TypeString())
	}
	joined := strings.Join(typs, ", ")
	if len(m.Returns) <= 1 {
		return joined
	}
	return fmt.Sprintf("(%s)", joined)
}
// ReturnArgNameList renders the result names as a comma-separated list.
func (m *method) ReturnArgNameList() string {
	var names []string
	for _, r := range m.Returns {
		names = append(names, r.Name)
	}
	return strings.Join(names, ", ")
}
// param describes a single parameter or result of a mocked method.
type param struct {
// Name is the variable's name, or a positional fallback when it is unnamed.
Name string
// Type is the type rendered as Go source text.
Type string
// Variadic marks the final parameter of a variadic method; Type then
// holds the slice form beginning with "[]".
Variadic bool
}
// String renders the param as "name type", as written in a parameter list.
func (p param) String() string {
	typ := p.TypeString()
	return fmt.Sprintf("%s %s", p.Name, typ)
}
// CallName returns the expression used to forward the param in a call: its
// name, followed by "..." when the param is variadic.
func (p param) CallName() string {
	name := p.Name
	if p.Variadic {
		name += "..."
	}
	return name
}
// TypeString returns the type as written in a signature. For a variadic param
// the leading two characters of Type (the "[]") are replaced by "...".
func (p param) TypeString() string {
	if !p.Variadic {
		return p.Type
	}
	return "..." + p.Type[2:]
}
// templateFuncs holds the helper functions registered on the mock template.
var templateFuncs = template.FuncMap{
	// Exported turns s into an exported identifier. A name that matches one
	// of golintInitialisms case-insensitively is returned as that initialism;
	// any other name has its first character upper-cased. The empty string
	// is returned unchanged.
	"Exported": func(s string) string {
		if s == "" {
			return ""
		}
		// Upper-case once instead of once per initialism.
		upper := strings.ToUpper(s)
		for _, initialism := range golintInitialisms {
			if upper == initialism {
				return initialism
			}
		}
		// Locate the end of the first rune so that a multi-byte first
		// character is upper-cased as a whole rather than cut at its first
		// byte. Ranging over a string yields the byte offset of each rune,
		// so the first non-zero offset is where the second rune starts.
		end := len(s)
		for i := range s {
			if i > 0 {
				end = i
				break
			}
		}
		return strings.ToUpper(s[:end]) + s[end:]
	},
}
// stripVendorPath strips the vendor dir prefix from a package path.
// For example we might encounter an absolute path like
// github.com/foo/bar/vendor/github.com/pkg/errors which is resolved
// to github.com/pkg/errors.
//
// Paths without a "/vendor/" element are returned unchanged. When the element
// occurs more than once, everything after the first occurrence is kept and
// the later "/vendor/" separators collapse to a single "/".
func stripVendorPath(p string) string {
	const marker = "/vendor/"
	first := strings.Index(p, marker)
	if first < 0 {
		return p
	}
	rest := strings.Split(p[first+len(marker):], marker)
	return strings.TrimLeft(path.Join(rest...), "/")
}
// stripGopath takes the directory to a package and removes the
// $GOPATH/src path to get the canonical package name.
//
// Every source root reported by build.Default.SrcDirs() is tried in order.
// The first root that contains p yields the slash-separated path of p
// relative to that root. p is returned unchanged when no root contains it.
func stripGopath(p string) string {
	// A relative path leaves the root only when it is exactly ".." or starts
	// with ".." followed by a separator. A plain ".." prefix test would also
	// reject directories inside the root whose name merely begins with two
	// dots (for example "..hidden/pkg").
	parentPrefix := ".." + string(filepath.Separator)
	for _, srcDir := range build.Default.SrcDirs() {
		rel, err := filepath.Rel(srcDir, p)
		if err != nil || rel == ".." || strings.HasPrefix(rel, parentPrefix) {
			continue
		}
		return filepath.ToSlash(rel)
	}
	return p
}