191 lines
4.7 KiB
Go
191 lines
4.7 KiB
Go
package registry
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"go/types"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
|
|
"golang.org/x/tools/go/packages"
|
|
)
|
|
|
|
// Registry encapsulates types information for the source and mock
|
|
// destination package. For the mock package, it tracks the list of
|
|
// imports and ensures there are no conflicts in the imported package
|
|
// qualifiers.
|
|
type Registry struct {
|
|
srcPkg *packages.Package
|
|
moqPkgPath string
|
|
aliases map[string]string
|
|
imports map[string]*Package
|
|
}
|
|
|
|
// New loads the source package info and returns a new instance of
|
|
// Registry.
|
|
func New(srcDir, moqPkg string) (*Registry, error) {
|
|
srcPkg, err := pkgInfoFromPath(
|
|
srcDir, packages.NeedName|packages.NeedSyntax|packages.NeedTypes|packages.NeedTypesInfo,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("couldn't load source package: %s", err)
|
|
}
|
|
|
|
return &Registry{
|
|
srcPkg: srcPkg,
|
|
moqPkgPath: findPkgPath(moqPkg, srcPkg),
|
|
aliases: parseImportsAliases(srcPkg),
|
|
imports: make(map[string]*Package),
|
|
}, nil
|
|
}
|
|
|
|
// SrcPkg returns the types info for the source package.
|
|
func (r Registry) SrcPkg() *types.Package {
|
|
return r.srcPkg.Types
|
|
}
|
|
|
|
// SrcPkgName returns the name of the source package.
|
|
func (r Registry) SrcPkgName() string {
|
|
return r.srcPkg.Name
|
|
}
|
|
|
|
// LookupInterface returns the underlying interface definition of the
|
|
// given interface name.
|
|
func (r Registry) LookupInterface(name string) (*types.Interface, error) {
|
|
obj := r.SrcPkg().Scope().Lookup(name)
|
|
if obj == nil {
|
|
return nil, fmt.Errorf("interface not found: %s", name)
|
|
}
|
|
|
|
if !types.IsInterface(obj.Type()) {
|
|
return nil, fmt.Errorf("%s (%s) is not an interface", name, obj.Type())
|
|
}
|
|
|
|
return obj.Type().Underlying().(*types.Interface).Complete(), nil
|
|
}
|
|
|
|
// MethodScope returns a new MethodScope.
|
|
func (r *Registry) MethodScope() *MethodScope {
|
|
return &MethodScope{
|
|
registry: r,
|
|
moqPkgPath: r.moqPkgPath,
|
|
conflicted: map[string]bool{},
|
|
}
|
|
}
|
|
|
|
// AddImport adds the given package to the set of imports. It generates a
|
|
// suitable alias if there are any conflicts with previously imported
|
|
// packages.
|
|
func (r *Registry) AddImport(pkg *types.Package) *Package {
|
|
path := stripVendorPath(pkg.Path())
|
|
if path == r.moqPkgPath {
|
|
return nil
|
|
}
|
|
|
|
if imprt, ok := r.imports[path]; ok {
|
|
return imprt
|
|
}
|
|
|
|
imprt := Package{pkg: pkg, Alias: r.aliases[path]}
|
|
|
|
if conflict, ok := r.searchImport(imprt.Qualifier()); ok {
|
|
resolveImportConflict(&imprt, conflict, 0)
|
|
}
|
|
|
|
r.imports[path] = &imprt
|
|
return &imprt
|
|
}
|
|
|
|
// Imports returns the list of imported packages. The list is sorted by
|
|
// path.
|
|
func (r Registry) Imports() []*Package {
|
|
imports := make([]*Package, 0, len(r.imports))
|
|
for _, imprt := range r.imports {
|
|
imports = append(imports, imprt)
|
|
}
|
|
sort.Slice(imports, func(i, j int) bool {
|
|
return imports[i].Path() < imports[j].Path()
|
|
})
|
|
return imports
|
|
}
|
|
|
|
func (r Registry) searchImport(name string) (*Package, bool) {
|
|
for _, imprt := range r.imports {
|
|
if imprt.Qualifier() == name {
|
|
return imprt, true
|
|
}
|
|
}
|
|
|
|
return nil, false
|
|
}
|
|
|
|
func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, error) {
|
|
pkgs, err := packages.Load(&packages.Config{
|
|
Mode: mode,
|
|
Dir: srcDir,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(pkgs) == 0 {
|
|
return nil, errors.New("package not found")
|
|
}
|
|
if len(pkgs) > 1 {
|
|
return nil, errors.New("found more than one package")
|
|
}
|
|
if errs := pkgs[0].Errors; len(errs) != 0 {
|
|
if len(errs) == 1 {
|
|
return nil, errs[0]
|
|
}
|
|
return nil, fmt.Errorf("%s (and %d more errors)", errs[0], len(errs)-1)
|
|
}
|
|
return pkgs[0], nil
|
|
}
|
|
|
|
func findPkgPath(pkgInputVal string, srcPkg *packages.Package) string {
|
|
if pkgInputVal == "" {
|
|
return srcPkg.PkgPath
|
|
}
|
|
if pkgInDir(srcPkg.PkgPath, pkgInputVal) {
|
|
return srcPkg.PkgPath
|
|
}
|
|
subdirectoryPath := filepath.Join(srcPkg.PkgPath, pkgInputVal)
|
|
if pkgInDir(subdirectoryPath, pkgInputVal) {
|
|
return subdirectoryPath
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func pkgInDir(pkgName, dir string) bool {
|
|
currentPkg, err := pkgInfoFromPath(dir, packages.NeedName)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return currentPkg.Name == pkgName || currentPkg.Name+"_test" == pkgName
|
|
}
|
|
|
|
func parseImportsAliases(pkg *packages.Package) map[string]string {
|
|
aliases := make(map[string]string)
|
|
for _, syntax := range pkg.Syntax {
|
|
for _, imprt := range syntax.Imports {
|
|
if imprt.Name != nil && imprt.Name.Name != "." && imprt.Name.Name != "_" {
|
|
aliases[strings.Trim(imprt.Path.Value, `"`)] = imprt.Name.Name
|
|
}
|
|
}
|
|
}
|
|
return aliases
|
|
}
|
|
|
|
// resolveImportConflict generates and assigns a unique alias for
|
|
// packages with conflicting qualifiers.
|
|
func resolveImportConflict(a, b *Package, lvl int) {
|
|
u1, u2 := a.uniqueName(lvl), b.uniqueName(lvl)
|
|
if u1 != u2 {
|
|
a.Alias, b.Alias = u1, u2
|
|
return
|
|
}
|
|
|
|
resolveImportConflict(a, b, lvl+1)
|
|
}
|