added support for stream type npipe (Windows named pipe)
Signed-off-by: Sam Alba <samalba@users.noreply.github.com>
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Microsoft/go-winio"
|
||||
"github.com/moby/buildkit/session"
|
||||
"github.com/moby/buildkit/session/sshforward"
|
||||
"google.golang.org/grpc"
|
||||
@@ -14,7 +15,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
unixPrefix = "unix="
|
||||
unixPrefix = "unix="
|
||||
npipePrefix = "npipe="
|
||||
)
|
||||
|
||||
type SocketProvider struct {
|
||||
@@ -33,12 +35,26 @@ func (sp *SocketProvider) CheckAgent(ctx context.Context, req *sshforward.CheckA
|
||||
if req.ID != "" {
|
||||
id = req.ID
|
||||
}
|
||||
if !strings.HasPrefix(id, unixPrefix) {
|
||||
if !strings.HasPrefix(id, unixPrefix) && !strings.HasPrefix(id, npipePrefix) {
|
||||
return &sshforward.CheckAgentResponse{}, fmt.Errorf("invalid socket forward key %s", id)
|
||||
}
|
||||
return &sshforward.CheckAgentResponse{}, nil
|
||||
}
|
||||
|
||||
func dialStream(id string) (net.Conn, error) {
|
||||
switch {
|
||||
case strings.HasPrefix(id, unixPrefix):
|
||||
id = strings.TrimPrefix(id, unixPrefix)
|
||||
return net.DialTimeout("unix", id, time.Second)
|
||||
case strings.HasPrefix(id, npipePrefix):
|
||||
id = strings.TrimPrefix(id, npipePrefix)
|
||||
dur := time.Second
|
||||
return winio.DialPipe(id, &dur)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid socket forward key %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
func (sp *SocketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer) error {
|
||||
id := sshforward.DefaultID
|
||||
|
||||
@@ -48,13 +64,7 @@ func (sp *SocketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer)
|
||||
id = v[0]
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(id, unixPrefix) {
|
||||
return fmt.Errorf("invalid socket forward key %s", id)
|
||||
}
|
||||
|
||||
id = strings.TrimPrefix(id, unixPrefix)
|
||||
|
||||
conn, err := net.DialTimeout("unix", id, time.Second)
|
||||
conn, err := dialStream(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to %s: %w", id, err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user