// Copyright (c) 2021 Gitpod GmbH. All rights reserved. // Licensed under the GNU Affero General Public License (AGPL). // See License-AGPL.txt in the project root for license information. package sshproxy import ( "context" "crypto/subtle" "net" "regexp" "strings" "time" "github.com/gitpod-io/gitpod/common-go/analytics" "github.com/gitpod-io/gitpod/common-go/log" supervisor "github.com/gitpod-io/gitpod/supervisor/api" tracker "github.com/gitpod-io/gitpod/ws-proxy/pkg/analytics" p "github.com/gitpod-io/gitpod/ws-proxy/pkg/proxy" "github.com/prometheus/client_golang/prometheus" "golang.org/x/crypto/ssh" "golang.org/x/xerrors" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "sigs.k8s.io/controller-runtime/pkg/metrics" ) const GitpodUsername = "gitpod" // This is copy from proxy/workspacerouter.go const workspaceIDRegex = "([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}|[0-9a-z]{2,16}-[0-9a-z]{2,16}-[0-9a-z]{8,11})" var ( SSHConnectionCount = prometheus.NewGauge(prometheus.GaugeOpts{ Name: "gitpod_ws_proxy_ssh_connection_count", Help: "Current number of SSH connection", }) SSHAttemptTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "gitpod_ws_proxy_ssh_attempt_total", Help: "Total number of SSH attempt", }, []string{"status", "error_type"}) ) var ( ErrWorkspaceNotFound = NewSSHError("WS_NOTFOUND", "not found workspace") ErrWorkspaceIDInvalid = NewSSHError("WS_ID_INVALID", "workspace id invalid") ErrAuthFailed = NewSSHError("AUTH_FAILED", "auth failed") ErrUsernameFormat = NewSSHError("USER_FORMAT", "username format is not correct") ErrMissPrivateKey = NewSSHError("MISS_KEY", "missing privateKey") ErrConnFailed = NewSSHError("CONN_FAILED", "cannot to connect with workspace") ErrCreateSSHKey = NewSSHError("CREATE_KEY_FAILED", "cannot create private pair in workspace") ) type SSHError struct { shortName string description string } func (e SSHError) Error() string { return e.description } func (e SSHError) ShortName() string { return e.shortName } func NewSSHError(shortName string, description string) SSHError { return SSHError{shortName: shortName, description: description} } type Session struct { Conn *ssh.ServerConn WorkspaceID string InstanceID string PublicKey ssh.PublicKey WorkspacePrivateKey ssh.Signer } type Server struct { Heartbeater Heartbeat sshConfig *ssh.ServerConfig workspaceInfoProvider p.WorkspaceInfoProvider } func init() { metrics.Registry.MustRegister( SSHConnectionCount, SSHAttemptTotal, ) } // New creates a new SSH proxy server func New(signers []ssh.Signer, workspaceInfoProvider p.WorkspaceInfoProvider, heartbeat Heartbeat) *Server { server := &Server{ workspaceInfoProvider: workspaceInfoProvider, Heartbeater: &noHeartbeat{}, } if heartbeat != nil { server.Heartbeater = heartbeat } server.sshConfig = &ssh.ServerConfig{ ServerVersion: "SSH-2.0-GITPOD-GATEWAY", PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (perm *ssh.Permissions, err error) { workspaceId, ownerToken := conn.User(), string(password) wsInfo, err := server.GetWorkspaceInfo(workspaceId) defer func() { server.TrackSSHConnection(wsInfo, "auth", err) }() if err != nil { args := strings.Split(conn.User(), "#") if len(args) != 2 { return } workspaceId, ownerToken = args[0], args[1] wsInfo, err = server.GetWorkspaceInfo(workspaceId) if err != nil { return nil, ErrWorkspaceNotFound } if wsInfo.Auth.OwnerToken == ownerToken { return nil, ErrMissPrivateKey } return nil, ErrAuthFailed } if wsInfo.Auth.OwnerToken != ownerToken { return nil, ErrAuthFailed } return &ssh.Permissions{ Extensions: map[string]string{ "workspaceId": workspaceId, }, }, nil }, PublicKeyCallback: func(conn ssh.ConnMetadata, pk ssh.PublicKey) (perm *ssh.Permissions, err error) { args := strings.Split(conn.User(), "#") workspaceId := args[0] wsInfo, err := server.GetWorkspaceInfo(workspaceId) if err != nil { return nil, err } defer func() { server.TrackSSHConnection(wsInfo, "auth", err) }() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() ok, _ := server.VerifyPublicKey(ctx, wsInfo, pk) if ok { return &ssh.Permissions{ Extensions: map[string]string{ "workspaceId": workspaceId, }, }, nil } // workspaceId#ownerToken if len(args) != 2 || wsInfo.Auth.OwnerToken != args[1] { return nil, ErrAuthFailed } return &ssh.Permissions{ Extensions: map[string]string{ "workspaceId": workspaceId, }, }, nil }, } for _, s := range signers { server.sshConfig.AddHostKey(s) } return server } func ReportSSHAttemptMetrics(err error) { if err == nil { SSHAttemptTotal.WithLabelValues("success", "").Inc() return } errorType := "OTHERS" if serverAuthErr, ok := err.(*ssh.ServerAuthError); ok && len(serverAuthErr.Errors) > 0 { if authErr, ok := serverAuthErr.Errors[len(serverAuthErr.Errors)-1].(SSHError); ok { errorType = authErr.ShortName() } } SSHAttemptTotal.WithLabelValues("failed", errorType).Inc() } func (s *Server) RequestForward(reqs <-chan *ssh.Request, targetConn ssh.Conn) { for req := range reqs { result, payload, err := targetConn.SendRequest(req.Type, req.WantReply, req.Payload) if err != nil { continue } _ = req.Reply(result, payload) } } func (s *Server) HandleConn(c net.Conn) { clientConn, clientChans, clientReqs, err := ssh.NewServerConn(c, s.sshConfig) if err != nil { c.Close() ReportSSHAttemptMetrics(err) return } defer clientConn.Close() if clientConn.Permissions == nil || clientConn.Permissions.Extensions == nil || clientConn.Permissions.Extensions["workspaceId"] == "" { return } workspaceId := clientConn.Permissions.Extensions["workspaceId"] wsInfo := s.workspaceInfoProvider.WorkspaceInfo(workspaceId) if wsInfo == nil { ReportSSHAttemptMetrics(ErrWorkspaceNotFound) return } ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) key, err := s.GetWorkspaceSSHKey(ctx, wsInfo.IPAddress) if err != nil { cancel() s.TrackSSHConnection(wsInfo, "connect", ErrCreateSSHKey) ReportSSHAttemptMetrics(ErrCreateSSHKey) log.WithField("instanceId", wsInfo.InstanceID).WithError(err).Error("failed to create private pair in workspace") return } cancel() session := &Session{ Conn: clientConn, WorkspaceID: workspaceId, InstanceID: wsInfo.InstanceID, WorkspacePrivateKey: key, } remoteAddr := wsInfo.IPAddress + ":23001" conn, err := net.Dial("tcp", remoteAddr) if err != nil { s.TrackSSHConnection(wsInfo, "connect", ErrConnFailed) ReportSSHAttemptMetrics(ErrConnFailed) log.WithField("instanceId", wsInfo.InstanceID).WithField("workspaceIP", wsInfo.IPAddress).WithError(err).Error("dail failed") return } defer conn.Close() workspaceConn, workspaceChans, workspaceReqs, err := ssh.NewClientConn(conn, remoteAddr, &ssh.ClientConfig{ HostKeyCallback: ssh.InsecureIgnoreHostKey(), User: GitpodUsername, Auth: []ssh.AuthMethod{ ssh.PublicKeysCallback(func() (signers []ssh.Signer, err error) { return []ssh.Signer{key}, nil }), }, Timeout: 10 * time.Second, }) if err != nil { s.TrackSSHConnection(wsInfo, "connect", ErrConnFailed) ReportSSHAttemptMetrics(ErrConnFailed) log.WithField("instanceId", wsInfo.InstanceID).WithField("workspaceIP", wsInfo.IPAddress).WithError(err).Error("connect failed") return } s.Heartbeater.SendHeartbeat(wsInfo.InstanceID, false) ctx, cancel = context.WithCancel(context.Background()) s.TrackSSHConnection(wsInfo, "connect", nil) SSHConnectionCount.Inc() ReportSSHAttemptMetrics(nil) forwardRequests := func(reqs <-chan *ssh.Request, targetConn ssh.Conn) { for req := range reqs { result, payload, err := targetConn.SendRequest(req.Type, req.WantReply, req.Payload) if err != nil { continue } _ = req.Reply(result, payload) } } // client -> workspace global request forward go forwardRequests(clientReqs, workspaceConn) // workspce -> client global request forward go forwardRequests(workspaceReqs, clientConn) go func() { for newChannel := range workspaceChans { go s.ChannelForward(ctx, session, clientConn, newChannel) } }() go func() { for newChannel := range clientChans { go s.ChannelForward(ctx, session, workspaceConn, newChannel) } }() go func() { clientConn.Wait() cancel() }() go func() { workspaceConn.Wait() cancel() }() <-ctx.Done() SSHConnectionCount.Dec() workspaceConn.Close() clientConn.Close() cancel() } func (s *Server) GetWorkspaceInfo(workspaceId string) (*p.WorkspaceInfo, error) { wsInfo := s.workspaceInfoProvider.WorkspaceInfo(workspaceId) if wsInfo == nil { if matched, _ := regexp.Match(workspaceIDRegex, []byte(workspaceId)); matched { return nil, ErrWorkspaceNotFound } return nil, ErrWorkspaceIDInvalid } return wsInfo, nil } func (s *Server) TrackSSHConnection(wsInfo *p.WorkspaceInfo, phase string, err error) { // if we didn't find an associated user, we don't want to track if wsInfo == nil { return } propertics := make(map[string]interface{}) propertics["workspaceId"] = wsInfo.WorkspaceID propertics["instanceId"] = wsInfo.InstanceID propertics["state"] = "success" propertics["phase"] = phase if err != nil { propertics["state"] = "failed" propertics["cause"] = err.Error() } tracker.Track(analytics.TrackMessage{ Identity: analytics.Identity{UserID: wsInfo.OwnerUserId}, Event: "ssh_connection", Properties: propertics, }) } func (s *Server) VerifyPublicKey(ctx context.Context, wsInfo *p.WorkspaceInfo, pk ssh.PublicKey) (bool, error) { for _, keyStr := range wsInfo.SSHPublicKeys { key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyStr)) if err != nil { continue } keyData := key.Marshal() pkd := pk.Marshal() if len(keyData) == len(pkd) && subtle.ConstantTimeCompare(keyData, pkd) == 1 { return true, nil } } return false, nil } func (s *Server) GetWorkspaceSSHKey(ctx context.Context, workspaceIP string) (ssh.Signer, error) { supervisorConn, err := grpc.Dial(workspaceIP+":22999", grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, xerrors.Errorf("failed connecting to supervisor: %w", err) } defer supervisorConn.Close() keyInfo, err := supervisor.NewControlServiceClient(supervisorConn).CreateSSHKeyPair(ctx, &supervisor.CreateSSHKeyPairRequest{}) if err != nil { return nil, xerrors.Errorf("failed getting ssh key pair info from supervisor: %w", err) } key, err := ssh.ParsePrivateKey([]byte(keyInfo.PrivateKey)) if err != nil { return nil, xerrors.Errorf("failed parse private key: %w", err) } return key, nil } func (s *Server) Serve(l net.Listener) error { for { conn, err := l.Accept() if err != nil { return err } go s.HandleConn(conn) } }