added code that checks vendor folders
This commit is contained in:
parent
63b6f24493
commit
2abe0a029a
@ -18,6 +18,7 @@ import (
|
||||
)
|
||||
|
||||
type customImporter struct {
|
||||
source string
|
||||
imported map[string]*types.Package
|
||||
base types.Importer
|
||||
skipTestFiles bool
|
||||
@ -43,8 +44,13 @@ func (i *customImporter) Import(path string) (*types.Package, error) {
|
||||
return pkg, nil
|
||||
}
|
||||
|
||||
func gopathDir(pkg string) (string, error) {
|
||||
for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") {
|
||||
func gopathDir(source, pkg string) (string, error) {
|
||||
// check vendor directory
|
||||
vendorPath, found := vendorPath(source, pkg)
|
||||
if found {
|
||||
return vendorPath, nil
|
||||
}
|
||||
for _, gopath := range gopaths() {
|
||||
absPath, err := filepath.Abs(path.Join(gopath, "src", pkg))
|
||||
if err != nil {
|
||||
return "", err
|
||||
@ -53,18 +59,50 @@ func gopathDir(pkg string) (string, error) {
|
||||
return absPath, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("%s not in $GOPATH", pkg)
|
||||
return "", fmt.Errorf("%s not in $GOPATH or %s", pkg, path.Join(source, "vendor"))
|
||||
}
|
||||
|
||||
func vendorPath(source, pkg string) (string, bool) {
|
||||
for {
|
||||
if isGopath(source) {
|
||||
return "", false
|
||||
}
|
||||
vendorPath, err := filepath.Abs(path.Join(source, "vendor", pkg))
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
if dir, err := os.Stat(vendorPath); err == nil && dir.IsDir() {
|
||||
return vendorPath, true
|
||||
}
|
||||
source = filepath.Dir(source)
|
||||
if source == "." {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func removeGopath(p string) string {
|
||||
for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") {
|
||||
for _, gopath := range gopaths() {
|
||||
p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func gopaths() []string {
|
||||
return strings.Split(os.Getenv("GOPATH"), string(filepath.ListSeparator))
|
||||
}
|
||||
|
||||
func isGopath(path string) bool {
|
||||
for _, p := range gopaths() {
|
||||
if p == path {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *customImporter) fsPkg(pkg string) (*types.Package, error) {
|
||||
dir, err := gopathDir(pkg)
|
||||
dir, err := gopathDir(i.source, pkg)
|
||||
if err != nil {
|
||||
return importOrErr(i.base, pkg, err)
|
||||
}
|
||||
@ -118,8 +156,9 @@ func importOrErr(base types.Importer, pkg string, err error) (*types.Package, er
|
||||
}
|
||||
|
||||
// newImporter returns an importer that will try to import code from gopath before using go/importer.Default and skipping test files
|
||||
func newImporter() types.Importer {
|
||||
func newImporter(source string) types.Importer {
|
||||
return &customImporter{
|
||||
source: source,
|
||||
imported: make(map[string]*types.Package),
|
||||
base: importer.Default(),
|
||||
skipTestFiles: true,
|
||||
@ -138,8 +177,8 @@ func newImporter() types.Importer {
|
||||
// stripGopath teks the directory to a package and remove the gopath to get the
|
||||
// cannonical package name
|
||||
func stripGopath(p string) string {
|
||||
for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") {
|
||||
p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1)
|
||||
for _, gopath := range gopaths() {
|
||||
p = strings.TrimPrefix(p, path.Join(gopath, "src")+"/")
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
|
||||
files[i] = file
|
||||
i++
|
||||
}
|
||||
conf := types.Config{Importer: newImporter()}
|
||||
conf := types.Config{Importer: newImporter(m.src)}
|
||||
tpkg, err := conf.Check(m.src, m.fset, files, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -2,7 +2,6 @@ package moq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@ -183,5 +182,14 @@ func TestVendoredPackages(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("mock error: %s", err)
|
||||
}
|
||||
log.Println(buf.String())
|
||||
s := buf.String()
|
||||
// assertions of things that should be mentioned
|
||||
var strs = []string{
|
||||
`"github.com/matryer/somerepo"`,
|
||||
}
|
||||
for _, str := range strs {
|
||||
if !strings.Contains(s, str) {
|
||||
t.Errorf("expected but missing: \"%s\"", str)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user