Use x/tools/go/loader to load packages instead of custom importer.
This commit is contained in:
parent
c80e9a745b
commit
565e1649f5
2
main.go
2
main.go
@ -2,12 +2,12 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/matryer/moq/pkg/moq"
|
"github.com/matryer/moq/pkg/moq"
|
||||||
)
|
)
|
||||||
|
@ -1,186 +0,0 @@
|
|||||||
package moq
|
|
||||||
|
|
||||||
// taken from https://github.com/ernesto-jimenez/gogen
|
|
||||||
// Copyright (c) 2015 Ernesto Jiménez
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"go/ast"
|
|
||||||
"go/importer"
|
|
||||||
"go/parser"
|
|
||||||
"go/token"
|
|
||||||
"go/types"
|
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type customImporter struct {
|
|
||||||
source string
|
|
||||||
imported map[string]*types.Package
|
|
||||||
base types.Importer
|
|
||||||
skipTestFiles bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *customImporter) Import(path string) (*types.Package, error) {
|
|
||||||
var err error
|
|
||||||
if path == "" || path[0] == '.' {
|
|
||||||
path, err = filepath.Abs(filepath.Clean(path))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
path = stripGopath(path)
|
|
||||||
}
|
|
||||||
if pkg, ok := i.imported[path]; ok {
|
|
||||||
return pkg, nil
|
|
||||||
}
|
|
||||||
pkg, err := i.fsPkg(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
i.imported[path] = pkg
|
|
||||||
return pkg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func gopathDir(source, pkg string) (string, error) {
|
|
||||||
// check vendor directory
|
|
||||||
vendorPath, found := vendorPath(source, pkg)
|
|
||||||
if found {
|
|
||||||
return vendorPath, nil
|
|
||||||
}
|
|
||||||
for _, gopath := range gopaths() {
|
|
||||||
absPath, err := filepath.Abs(path.Join(gopath, "src", pkg))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if dir, err := os.Stat(absPath); err == nil && dir.IsDir() {
|
|
||||||
return absPath, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("%s not in $GOPATH or %s", pkg, path.Join(source, "vendor"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func vendorPath(source, pkg string) (string, bool) {
|
|
||||||
for {
|
|
||||||
if isGopath(source) {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
source, err = filepath.Abs(source)
|
|
||||||
if err != nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
vendorPath, err := filepath.Abs(path.Join(source, "vendor", pkg))
|
|
||||||
if err != nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
if dir, err := os.Stat(vendorPath); err == nil && dir.IsDir() {
|
|
||||||
return vendorPath, true
|
|
||||||
}
|
|
||||||
source = filepath.Dir(source)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeGopath(p string) string {
|
|
||||||
for _, gopath := range gopaths() {
|
|
||||||
p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1)
|
|
||||||
}
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func gopaths() []string {
|
|
||||||
return strings.Split(os.Getenv("GOPATH"), string(filepath.ListSeparator))
|
|
||||||
}
|
|
||||||
|
|
||||||
func isGopath(path string) bool {
|
|
||||||
for _, p := range gopaths() {
|
|
||||||
if p == path {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *customImporter) fsPkg(pkg string) (*types.Package, error) {
|
|
||||||
dir, err := gopathDir(i.source, pkg)
|
|
||||||
if err != nil {
|
|
||||||
return importOrErr(i.base, pkg, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dirFiles, err := ioutil.ReadDir(dir)
|
|
||||||
if err != nil {
|
|
||||||
return importOrErr(i.base, pkg, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fset := token.NewFileSet()
|
|
||||||
var files []*ast.File
|
|
||||||
for _, fileInfo := range dirFiles {
|
|
||||||
if fileInfo.IsDir() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
n := fileInfo.Name()
|
|
||||||
if path.Ext(fileInfo.Name()) != ".go" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if i.skipTestFiles && strings.Contains(fileInfo.Name(), "_test.go") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
file := path.Join(dir, n)
|
|
||||||
src, err := ioutil.ReadFile(file)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
f, err := parser.ParseFile(fset, file, src, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
files = append(files, f)
|
|
||||||
}
|
|
||||||
conf := types.Config{
|
|
||||||
Importer: i,
|
|
||||||
}
|
|
||||||
p, err := conf.Check(pkg, fset, files, nil)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return importOrErr(i.base, pkg, err)
|
|
||||||
}
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func importOrErr(base types.Importer, pkg string, err error) (*types.Package, error) {
|
|
||||||
p, impErr := base.Import(pkg)
|
|
||||||
if impErr != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// newImporter returns an importer that will try to import code from gopath before using go/importer.Default and skipping test files
|
|
||||||
func newImporter(source string) types.Importer {
|
|
||||||
return &customImporter{
|
|
||||||
source: source,
|
|
||||||
imported: make(map[string]*types.Package),
|
|
||||||
base: importer.Default(),
|
|
||||||
skipTestFiles: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// // DefaultWithTestFiles same as Default but it parses test files too
|
|
||||||
// func DefaultWithTestFiles() types.Importer {
|
|
||||||
// return &customImporter{
|
|
||||||
// imported: make(map[string]*types.Package),
|
|
||||||
// base: importer.Default(),
|
|
||||||
// skipTestFiles: false,
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// stripGopath teks the directory to a package and remove the gopath to get the
|
|
||||||
// canonical package name.
|
|
||||||
func stripGopath(p string) string {
|
|
||||||
for _, gopath := range gopaths() {
|
|
||||||
p = strings.TrimPrefix(p, path.Join(gopath, "src")+"/")
|
|
||||||
}
|
|
||||||
return p
|
|
||||||
}
|
|
@ -11,8 +11,12 @@ import (
|
|||||||
"go/types"
|
"go/types"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
|
"golang.org/x/tools/go/loader"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This list comes from the golint codebase. Golint will complain about any of
|
// This list comes from the golint codebase. Golint will complain about any of
|
||||||
@ -75,11 +79,13 @@ func New(src, packageName string) (*Mocker, error) {
|
|||||||
noTestFiles := func(i os.FileInfo) bool {
|
noTestFiles := func(i os.FileInfo) bool {
|
||||||
return !strings.HasSuffix(i.Name(), "_test.go")
|
return !strings.HasSuffix(i.Name(), "_test.go")
|
||||||
}
|
}
|
||||||
|
|
||||||
pkgs, err := parser.ParseDir(fset, src, noTestFiles, parser.SpuriousErrors)
|
pkgs, err := parser.ParseDir(fset, src, noTestFiles, parser.SpuriousErrors)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(packageName) == 0 {
|
if len(packageName) == 0 {
|
||||||
|
|
||||||
for pkgName := range pkgs {
|
for pkgName := range pkgs {
|
||||||
if strings.Contains(pkgName, "_test") {
|
if strings.Contains(pkgName, "_test") {
|
||||||
continue
|
continue
|
||||||
@ -110,23 +116,20 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
|
|||||||
if len(name) == 0 {
|
if len(name) == 0 {
|
||||||
return errors.New("must specify one interface")
|
return errors.New("must specify one interface")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pkgInfo, err := m.pkgInfoFromPath(m.src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
doc := doc{
|
doc := doc{
|
||||||
PackageName: m.pkgName,
|
PackageName: m.pkgName,
|
||||||
Imports: moqImports,
|
Imports: moqImports,
|
||||||
}
|
}
|
||||||
|
|
||||||
mocksMethods := false
|
mocksMethods := false
|
||||||
for _, pkg := range m.pkgs {
|
|
||||||
i := 0
|
tpkg := pkgInfo.Pkg
|
||||||
files := make([]*ast.File, len(pkg.Files))
|
|
||||||
for _, file := range pkg.Files {
|
|
||||||
files[i] = file
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
conf := types.Config{Importer: newImporter(m.src)}
|
|
||||||
tpkg, err := conf.Check(m.src, m.fset, files, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for _, n := range name {
|
for _, n := range name {
|
||||||
iface := tpkg.Scope().Lookup(n)
|
iface := tpkg.Scope().Lookup(n)
|
||||||
if iface == nil {
|
if iface == nil {
|
||||||
@ -152,15 +155,17 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
|
|||||||
}
|
}
|
||||||
doc.Objects = append(doc.Objects, obj)
|
doc.Objects = append(doc.Objects, obj)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if mocksMethods {
|
if mocksMethods {
|
||||||
doc.Imports = append(doc.Imports, "sync")
|
doc.Imports = append(doc.Imports, "sync")
|
||||||
}
|
}
|
||||||
|
|
||||||
for pkgToImport := range m.imports {
|
for pkgToImport := range m.imports {
|
||||||
doc.Imports = append(doc.Imports, pkgToImport)
|
doc.Imports = append(doc.Imports, stripVendorPath(pkgToImport))
|
||||||
}
|
}
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
err := m.tmpl.Execute(&buf, doc)
|
err = m.tmpl.Execute(&buf, doc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -211,6 +216,32 @@ func (m *Mocker) extractArgs(sig *types.Signature, list *types.Tuple, nameFormat
|
|||||||
return params
|
return params
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (*Mocker) pkgInfoFromPath(src string) (*loader.PackageInfo, error) {
|
||||||
|
|
||||||
|
abs, err := filepath.Abs(src)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pkgFull := stripGopath(abs)
|
||||||
|
|
||||||
|
conf := loader.Config{
|
||||||
|
ParserMode: parser.SpuriousErrors,
|
||||||
|
Cwd: src,
|
||||||
|
}
|
||||||
|
conf.Import(pkgFull)
|
||||||
|
lprog, err := conf.Load()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pkgInfo := lprog.Package(pkgFull)
|
||||||
|
if pkgInfo == nil {
|
||||||
|
return nil, errors.New("package was nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
return pkgInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
type doc struct {
|
type doc struct {
|
||||||
PackageName string
|
PackageName string
|
||||||
Objects []obj
|
Objects []obj
|
||||||
@ -291,3 +322,31 @@ var templateFuncs = template.FuncMap{
|
|||||||
return strings.ToUpper(s[0:1]) + s[1:]
|
return strings.ToUpper(s[0:1]) + s[1:]
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stripVendorPath strips the vendor dir prefix from a package path.
|
||||||
|
// For example we might encounter an absolute path like
|
||||||
|
// github.com/foo/bar/vendor/github.com/pkg/errors which is resolved
|
||||||
|
// to github.com/pkg/errors.
|
||||||
|
func stripVendorPath(p string) string {
|
||||||
|
parts := strings.Split(p, "/vendor/")
|
||||||
|
if len(parts) == 1 {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
return strings.TrimLeft(path.Join(parts[1:]...), "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripGopath takes the directory to a package and remove the gopath to get the
|
||||||
|
// canonical package name.
|
||||||
|
//
|
||||||
|
// taken from https://github.com/ernesto-jimenez/gogen
|
||||||
|
// Copyright (c) 2015 Ernesto Jiménez
|
||||||
|
func stripGopath(p string) string {
|
||||||
|
for _, gopath := range gopaths() {
|
||||||
|
p = strings.TrimPrefix(p, path.Join(gopath, "src")+"/")
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func gopaths() []string {
|
||||||
|
return strings.Split(os.Getenv("GOPATH"), string(filepath.ListSeparator))
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user