diff --git a/main.go b/main.go index d525994..fa317a3 100644 --- a/main.go +++ b/main.go @@ -2,12 +2,12 @@ package main import ( "bytes" - "errors" "flag" "fmt" "io" "io/ioutil" "os" + "errors" "github.com/matryer/moq/pkg/moq" ) diff --git a/pkg/moq/importer.go b/pkg/moq/importer.go deleted file mode 100644 index 07747b8..0000000 --- a/pkg/moq/importer.go +++ /dev/null @@ -1,186 +0,0 @@ -package moq - -// taken from https://github.com/ernesto-jimenez/gogen -// Copyright (c) 2015 Ernesto Jiménez - -import ( - "fmt" - "go/ast" - "go/importer" - "go/parser" - "go/token" - "go/types" - "io/ioutil" - "os" - "path" - "path/filepath" - "strings" -) - -type customImporter struct { - source string - imported map[string]*types.Package - base types.Importer - skipTestFiles bool -} - -func (i *customImporter) Import(path string) (*types.Package, error) { - var err error - if path == "" || path[0] == '.' { - path, err = filepath.Abs(filepath.Clean(path)) - if err != nil { - return nil, err - } - path = stripGopath(path) - } - if pkg, ok := i.imported[path]; ok { - return pkg, nil - } - pkg, err := i.fsPkg(path) - if err != nil { - return nil, err - } - i.imported[path] = pkg - return pkg, nil -} - -func gopathDir(source, pkg string) (string, error) { - // check vendor directory - vendorPath, found := vendorPath(source, pkg) - if found { - return vendorPath, nil - } - for _, gopath := range gopaths() { - absPath, err := filepath.Abs(path.Join(gopath, "src", pkg)) - if err != nil { - return "", err - } - if dir, err := os.Stat(absPath); err == nil && dir.IsDir() { - return absPath, nil - } - } - return "", fmt.Errorf("%s not in $GOPATH or %s", pkg, path.Join(source, "vendor")) -} - -func vendorPath(source, pkg string) (string, bool) { - for { - if isGopath(source) { - return "", false - } - var err error - source, err = filepath.Abs(source) - if err != nil { - return "", false - } - vendorPath, err := filepath.Abs(path.Join(source, "vendor", pkg)) - if err != nil { - return "", false - } - if dir, err := os.Stat(vendorPath); err == nil && dir.IsDir() { - return vendorPath, true - } - source = filepath.Dir(source) - } -} - -func removeGopath(p string) string { - for _, gopath := range gopaths() { - p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1) - } - return p -} - -func gopaths() []string { - return strings.Split(os.Getenv("GOPATH"), string(filepath.ListSeparator)) -} - -func isGopath(path string) bool { - for _, p := range gopaths() { - if p == path { - return true - } - } - return false -} - -func (i *customImporter) fsPkg(pkg string) (*types.Package, error) { - dir, err := gopathDir(i.source, pkg) - if err != nil { - return importOrErr(i.base, pkg, err) - } - - dirFiles, err := ioutil.ReadDir(dir) - if err != nil { - return importOrErr(i.base, pkg, err) - } - - fset := token.NewFileSet() - var files []*ast.File - for _, fileInfo := range dirFiles { - if fileInfo.IsDir() { - continue - } - n := fileInfo.Name() - if path.Ext(fileInfo.Name()) != ".go" { - continue - } - if i.skipTestFiles && strings.Contains(fileInfo.Name(), "_test.go") { - continue - } - file := path.Join(dir, n) - src, err := ioutil.ReadFile(file) - if err != nil { - return nil, err - } - f, err := parser.ParseFile(fset, file, src, 0) - if err != nil { - return nil, err - } - files = append(files, f) - } - conf := types.Config{ - Importer: i, - } - p, err := conf.Check(pkg, fset, files, nil) - - if err != nil { - return importOrErr(i.base, pkg, err) - } - return p, nil -} - -func importOrErr(base types.Importer, pkg string, err error) (*types.Package, error) { - p, impErr := base.Import(pkg) - if impErr != nil { - return nil, err - } - return p, nil -} - -// newImporter returns an importer that will try to import code from gopath before using go/importer.Default and skipping test files -func newImporter(source string) types.Importer { - return &customImporter{ - source: source, - imported: make(map[string]*types.Package), - base: importer.Default(), - skipTestFiles: true, - } -} - -// // DefaultWithTestFiles same as Default but it parses test files too -// func DefaultWithTestFiles() types.Importer { -// return &customImporter{ -// imported: make(map[string]*types.Package), -// base: importer.Default(), -// skipTestFiles: false, -// } -// } - -// stripGopath teks the directory to a package and remove the gopath to get the -// canonical package name. -func stripGopath(p string) string { - for _, gopath := range gopaths() { - p = strings.TrimPrefix(p, path.Join(gopath, "src")+"/") - } - return p -} diff --git a/pkg/moq/moq.go b/pkg/moq/moq.go index ae37714..952bce1 100644 --- a/pkg/moq/moq.go +++ b/pkg/moq/moq.go @@ -11,8 +11,12 @@ import ( "go/types" "io" "os" + "path" + "path/filepath" "strings" "text/template" + + "golang.org/x/tools/go/loader" ) // This list comes from the golint codebase. Golint will complain about any of @@ -75,11 +79,13 @@ func New(src, packageName string) (*Mocker, error) { noTestFiles := func(i os.FileInfo) bool { return !strings.HasSuffix(i.Name(), "_test.go") } + pkgs, err := parser.ParseDir(fset, src, noTestFiles, parser.SpuriousErrors) if err != nil { return nil, err } if len(packageName) == 0 { + for pkgName := range pkgs { if strings.Contains(pkgName, "_test") { continue @@ -110,57 +116,56 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { if len(name) == 0 { return errors.New("must specify one interface") } + + pkgInfo, err := m.pkgInfoFromPath(m.src) + if err != nil { + return err + } + doc := doc{ PackageName: m.pkgName, Imports: moqImports, } + mocksMethods := false - for _, pkg := range m.pkgs { - i := 0 - files := make([]*ast.File, len(pkg.Files)) - for _, file := range pkg.Files { - files[i] = file - i++ + + tpkg := pkgInfo.Pkg + for _, n := range name { + iface := tpkg.Scope().Lookup(n) + if iface == nil { + return fmt.Errorf("cannot find interface %s", n) } - conf := types.Config{Importer: newImporter(m.src)} - tpkg, err := conf.Check(m.src, m.fset, files, nil) - if err != nil { - return err + if !types.IsInterface(iface.Type()) { + return fmt.Errorf("%s (%s) not an interface", n, iface.Type().String()) } - for _, n := range name { - iface := tpkg.Scope().Lookup(n) - if iface == nil { - return fmt.Errorf("cannot find interface %s", n) - } - if !types.IsInterface(iface.Type()) { - return fmt.Errorf("%s (%s) not an interface", n, iface.Type().String()) - } - iiface := iface.Type().Underlying().(*types.Interface).Complete() - obj := obj{ - InterfaceName: n, - } - for i := 0; i < iiface.NumMethods(); i++ { - mocksMethods = true - meth := iiface.Method(i) - sig := meth.Type().(*types.Signature) - method := &method{ - Name: meth.Name(), - } - obj.Methods = append(obj.Methods, method) - method.Params = m.extractArgs(sig, sig.Params(), "in%d") - method.Returns = m.extractArgs(sig, sig.Results(), "out%d") - } - doc.Objects = append(doc.Objects, obj) + iiface := iface.Type().Underlying().(*types.Interface).Complete() + obj := obj{ + InterfaceName: n, } + for i := 0; i < iiface.NumMethods(); i++ { + mocksMethods = true + meth := iiface.Method(i) + sig := meth.Type().(*types.Signature) + method := &method{ + Name: meth.Name(), + } + obj.Methods = append(obj.Methods, method) + method.Params = m.extractArgs(sig, sig.Params(), "in%d") + method.Returns = m.extractArgs(sig, sig.Results(), "out%d") + } + doc.Objects = append(doc.Objects, obj) } + if mocksMethods { doc.Imports = append(doc.Imports, "sync") } + for pkgToImport := range m.imports { - doc.Imports = append(doc.Imports, pkgToImport) + doc.Imports = append(doc.Imports, stripVendorPath(pkgToImport)) } + var buf bytes.Buffer - err := m.tmpl.Execute(&buf, doc) + err = m.tmpl.Execute(&buf, doc) if err != nil { return err } @@ -211,6 +216,32 @@ func (m *Mocker) extractArgs(sig *types.Signature, list *types.Tuple, nameFormat return params } +func (*Mocker) pkgInfoFromPath(src string) (*loader.PackageInfo, error) { + + abs, err := filepath.Abs(src) + if err != nil { + return nil, err + } + pkgFull := stripGopath(abs) + + conf := loader.Config{ + ParserMode: parser.SpuriousErrors, + Cwd: src, + } + conf.Import(pkgFull) + lprog, err := conf.Load() + if err != nil { + return nil, err + } + + pkgInfo := lprog.Package(pkgFull) + if pkgInfo == nil { + return nil, errors.New("package was nil") + } + + return pkgInfo, nil +} + type doc struct { PackageName string Objects []obj @@ -291,3 +322,31 @@ var templateFuncs = template.FuncMap{ return strings.ToUpper(s[0:1]) + s[1:] }, } + +// stripVendorPath strips the vendor dir prefix from a package path. +// For example we might encounter an absolute path like +// github.com/foo/bar/vendor/github.com/pkg/errors which is resolved +// to github.com/pkg/errors. +func stripVendorPath(p string) string { + parts := strings.Split(p, "/vendor/") + if len(parts) == 1 { + return p + } + return strings.TrimLeft(path.Join(parts[1:]...), "/") +} + +// stripGopath takes the directory to a package and remove the gopath to get the +// canonical package name. +// +// taken from https://github.com/ernesto-jimenez/gogen +// Copyright (c) 2015 Ernesto Jiménez +func stripGopath(p string) string { + for _, gopath := range gopaths() { + p = strings.TrimPrefix(p, path.Join(gopath, "src")+"/") + } + return p +} + +func gopaths() []string { + return strings.Split(os.Getenv("GOPATH"), string(filepath.ListSeparator)) +}