Use x/tools/go/loader to load packages instead of custom importer.

This commit is contained in:
Stefan Warman 2018-01-11 17:48:30 +01:00
parent c80e9a745b
commit 565e1649f5
3 changed files with 96 additions and 223 deletions

View File

@ -2,12 +2,12 @@ package main
import ( import (
"bytes" "bytes"
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"errors"
"github.com/matryer/moq/pkg/moq" "github.com/matryer/moq/pkg/moq"
) )

View File

@ -1,186 +0,0 @@
package moq
// taken from https://github.com/ernesto-jimenez/gogen
// Copyright (c) 2015 Ernesto Jiménez
import (
"fmt"
"go/ast"
"go/importer"
"go/parser"
"go/token"
"go/types"
"io/ioutil"
"os"
"path"
"path/filepath"
"strings"
)
type customImporter struct {
source string
imported map[string]*types.Package
base types.Importer
skipTestFiles bool
}
func (i *customImporter) Import(path string) (*types.Package, error) {
var err error
if path == "" || path[0] == '.' {
path, err = filepath.Abs(filepath.Clean(path))
if err != nil {
return nil, err
}
path = stripGopath(path)
}
if pkg, ok := i.imported[path]; ok {
return pkg, nil
}
pkg, err := i.fsPkg(path)
if err != nil {
return nil, err
}
i.imported[path] = pkg
return pkg, nil
}
func gopathDir(source, pkg string) (string, error) {
// check vendor directory
vendorPath, found := vendorPath(source, pkg)
if found {
return vendorPath, nil
}
for _, gopath := range gopaths() {
absPath, err := filepath.Abs(path.Join(gopath, "src", pkg))
if err != nil {
return "", err
}
if dir, err := os.Stat(absPath); err == nil && dir.IsDir() {
return absPath, nil
}
}
return "", fmt.Errorf("%s not in $GOPATH or %s", pkg, path.Join(source, "vendor"))
}
func vendorPath(source, pkg string) (string, bool) {
for {
if isGopath(source) {
return "", false
}
var err error
source, err = filepath.Abs(source)
if err != nil {
return "", false
}
vendorPath, err := filepath.Abs(path.Join(source, "vendor", pkg))
if err != nil {
return "", false
}
if dir, err := os.Stat(vendorPath); err == nil && dir.IsDir() {
return vendorPath, true
}
source = filepath.Dir(source)
}
}
func removeGopath(p string) string {
for _, gopath := range gopaths() {
p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1)
}
return p
}
func gopaths() []string {
return strings.Split(os.Getenv("GOPATH"), string(filepath.ListSeparator))
}
func isGopath(path string) bool {
for _, p := range gopaths() {
if p == path {
return true
}
}
return false
}
func (i *customImporter) fsPkg(pkg string) (*types.Package, error) {
dir, err := gopathDir(i.source, pkg)
if err != nil {
return importOrErr(i.base, pkg, err)
}
dirFiles, err := ioutil.ReadDir(dir)
if err != nil {
return importOrErr(i.base, pkg, err)
}
fset := token.NewFileSet()
var files []*ast.File
for _, fileInfo := range dirFiles {
if fileInfo.IsDir() {
continue
}
n := fileInfo.Name()
if path.Ext(fileInfo.Name()) != ".go" {
continue
}
if i.skipTestFiles && strings.Contains(fileInfo.Name(), "_test.go") {
continue
}
file := path.Join(dir, n)
src, err := ioutil.ReadFile(file)
if err != nil {
return nil, err
}
f, err := parser.ParseFile(fset, file, src, 0)
if err != nil {
return nil, err
}
files = append(files, f)
}
conf := types.Config{
Importer: i,
}
p, err := conf.Check(pkg, fset, files, nil)
if err != nil {
return importOrErr(i.base, pkg, err)
}
return p, nil
}
func importOrErr(base types.Importer, pkg string, err error) (*types.Package, error) {
p, impErr := base.Import(pkg)
if impErr != nil {
return nil, err
}
return p, nil
}
// newImporter returns an importer that will try to import code from gopath before using go/importer.Default and skipping test files
func newImporter(source string) types.Importer {
return &customImporter{
source: source,
imported: make(map[string]*types.Package),
base: importer.Default(),
skipTestFiles: true,
}
}
// // DefaultWithTestFiles same as Default but it parses test files too
// func DefaultWithTestFiles() types.Importer {
// return &customImporter{
// imported: make(map[string]*types.Package),
// base: importer.Default(),
// skipTestFiles: false,
// }
// }
// stripGopath teks the directory to a package and remove the gopath to get the
// canonical package name.
func stripGopath(p string) string {
for _, gopath := range gopaths() {
p = strings.TrimPrefix(p, path.Join(gopath, "src")+"/")
}
return p
}

View File

@ -11,8 +11,12 @@ import (
"go/types" "go/types"
"io" "io"
"os" "os"
"path"
"path/filepath"
"strings" "strings"
"text/template" "text/template"
"golang.org/x/tools/go/loader"
) )
// This list comes from the golint codebase. Golint will complain about any of // This list comes from the golint codebase. Golint will complain about any of
@ -75,11 +79,13 @@ func New(src, packageName string) (*Mocker, error) {
noTestFiles := func(i os.FileInfo) bool { noTestFiles := func(i os.FileInfo) bool {
return !strings.HasSuffix(i.Name(), "_test.go") return !strings.HasSuffix(i.Name(), "_test.go")
} }
pkgs, err := parser.ParseDir(fset, src, noTestFiles, parser.SpuriousErrors) pkgs, err := parser.ParseDir(fset, src, noTestFiles, parser.SpuriousErrors)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(packageName) == 0 { if len(packageName) == 0 {
for pkgName := range pkgs { for pkgName := range pkgs {
if strings.Contains(pkgName, "_test") { if strings.Contains(pkgName, "_test") {
continue continue
@ -110,23 +116,20 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
if len(name) == 0 { if len(name) == 0 {
return errors.New("must specify one interface") return errors.New("must specify one interface")
} }
pkgInfo, err := m.pkgInfoFromPath(m.src)
if err != nil {
return err
}
doc := doc{ doc := doc{
PackageName: m.pkgName, PackageName: m.pkgName,
Imports: moqImports, Imports: moqImports,
} }
mocksMethods := false mocksMethods := false
for _, pkg := range m.pkgs {
i := 0 tpkg := pkgInfo.Pkg
files := make([]*ast.File, len(pkg.Files))
for _, file := range pkg.Files {
files[i] = file
i++
}
conf := types.Config{Importer: newImporter(m.src)}
tpkg, err := conf.Check(m.src, m.fset, files, nil)
if err != nil {
return err
}
for _, n := range name { for _, n := range name {
iface := tpkg.Scope().Lookup(n) iface := tpkg.Scope().Lookup(n)
if iface == nil { if iface == nil {
@ -152,15 +155,17 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
} }
doc.Objects = append(doc.Objects, obj) doc.Objects = append(doc.Objects, obj)
} }
}
if mocksMethods { if mocksMethods {
doc.Imports = append(doc.Imports, "sync") doc.Imports = append(doc.Imports, "sync")
} }
for pkgToImport := range m.imports { for pkgToImport := range m.imports {
doc.Imports = append(doc.Imports, pkgToImport) doc.Imports = append(doc.Imports, stripVendorPath(pkgToImport))
} }
var buf bytes.Buffer var buf bytes.Buffer
err := m.tmpl.Execute(&buf, doc) err = m.tmpl.Execute(&buf, doc)
if err != nil { if err != nil {
return err return err
} }
@ -211,6 +216,32 @@ func (m *Mocker) extractArgs(sig *types.Signature, list *types.Tuple, nameFormat
return params return params
} }
func (*Mocker) pkgInfoFromPath(src string) (*loader.PackageInfo, error) {
abs, err := filepath.Abs(src)
if err != nil {
return nil, err
}
pkgFull := stripGopath(abs)
conf := loader.Config{
ParserMode: parser.SpuriousErrors,
Cwd: src,
}
conf.Import(pkgFull)
lprog, err := conf.Load()
if err != nil {
return nil, err
}
pkgInfo := lprog.Package(pkgFull)
if pkgInfo == nil {
return nil, errors.New("package was nil")
}
return pkgInfo, nil
}
type doc struct { type doc struct {
PackageName string PackageName string
Objects []obj Objects []obj
@ -291,3 +322,31 @@ var templateFuncs = template.FuncMap{
return strings.ToUpper(s[0:1]) + s[1:] return strings.ToUpper(s[0:1]) + s[1:]
}, },
} }
// stripVendorPath strips the vendor dir prefix from a package path.
// For example we might encounter an absolute path like
// github.com/foo/bar/vendor/github.com/pkg/errors which is resolved
// to github.com/pkg/errors.
func stripVendorPath(p string) string {
parts := strings.Split(p, "/vendor/")
if len(parts) == 1 {
return p
}
return strings.TrimLeft(path.Join(parts[1:]...), "/")
}
// stripGopath takes the directory to a package and remove the gopath to get the
// canonical package name.
//
// taken from https://github.com/ernesto-jimenez/gogen
// Copyright (c) 2015 Ernesto Jiménez
func stripGopath(p string) string {
for _, gopath := range gopaths() {
p = strings.TrimPrefix(p, path.Join(gopath, "src")+"/")
}
return p
}
func gopaths() []string {
return strings.Split(os.Getenv("GOPATH"), string(filepath.ListSeparator))
}