2022-12-08 13:05:19 -03:00

278 lines
7.3 KiB
Go

// Copyright (c) 2021 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.
package ports
import (
"context"
"fmt"
"io"
"net"
"sort"
"strconv"
"sync"
"golang.org/x/xerrors"
"github.com/gitpod-io/gitpod/supervisor/api"
)
// PortTunnelDescription describes a tunnel as requested through Tunnel.
type PortTunnelDescription struct {
// LocalPort is the port the tunnel is registered under and the localhost
// port that EstablishTunnel dials. validate requires it to be in 1..65535.
LocalPort uint32
// TargetPort must not exceed 65535; validate accepts 0.
// NOTE(review): presumably the port to listen on at the remote side — confirm.
TargetPort uint32
// Visibility is the tunnel visibility as defined by the supervisor API.
Visibility api.TunnelVisiblity
}
// PortTunnelState is the state of a tunnel as reported to observers:
// its description plus the clients currently tunneling through it.
type PortTunnelState struct {
// Desc is the most recent description stored for the tunnel.
Desc PortTunnelDescription
// Clients maps a client ID to the target port that client passed to
// EstablishTunnel. An entry is removed once the client's last connection closes.
Clients map[string]uint32
}
// PortTunnel couples the observable state of a tunnel with its open connections.
type PortTunnel struct {
// State is the part of the tunnel that is handed out to observers.
State PortTunnelState
// Conns holds, per client ID, the set of connections established for that client.
Conns map[string]map[net.Conn]struct{}
}
// TunnelOptions tunes the behavior of Tunnel.
type TunnelOptions struct {
// SkipIfExists leaves a tunnel that is already registered for the local port
// untouched instead of overwriting its description.
SkipIfExists bool
}
// TunneledPortsInterface observes and manages the tunneled ports.
type TunneledPortsInterface interface {
// Observe starts observing the tunneled ports until the context is canceled.
// The list of tunneled ports is always the complete picture, i.e. if a single port changes,
// the whole list is returned.
// When the observer stops operating (because the context was canceled or an irrecoverable
// error occurred), the observer will close both channels.
Observe(ctx context.Context) (<-chan []PortTunnelState, <-chan error)
// Tunnel notifies clients to install listeners on remote machines.
// After that such clients should call EstablishTunnel to forward incoming connections.
Tunnel(ctx context.Context, options *TunnelOptions, descs ...*PortTunnelDescription) ([]uint32, error)
// CloseTunnel closes the tunnels registered for the given local ports.
CloseTunnel(ctx context.Context, localPorts ...uint32) ([]uint32, error)
// EstablishTunnel actually establishes the tunnel for an incoming connection on a remote machine.
EstablishTunnel(ctx context.Context, clientID string, localPort uint32, targetPort uint32) (net.Conn, error)
}
// TunneledPortsService observes the tunneled ports.
// Instances are created with NewTunneledPortsService, which initializes all fields.
type TunneledPortsService struct {
// mu guards tunnels; NewTunneledPortsService builds cond on this same lock.
mu *sync.RWMutex
// cond is broadcast whenever the set of tunnels or their clients changes,
// which wakes the goroutines started by Observe.
cond *sync.Cond
// tunnels holds the registered tunnels keyed by local port.
tunnels map[uint32]*PortTunnel
}
// NewTunneledPortsService creates a new instance without any tunnels.
// The returned service's mutex and condition variable are backed by one and
// the same lock. debugEnable is accepted but not used by this constructor.
func NewTunneledPortsService(debugEnable bool) *TunneledPortsService {
	lock := new(sync.RWMutex)
	svc := &TunneledPortsService{
		mu:      lock,
		tunnels: map[uint32]*PortTunnel{},
	}
	svc.cond = sync.NewCond(lock)
	return svc
}
// tunnelConn wraps a net.Conn and runs a callback the first time it is closed.
type tunnelConn struct {
	net.Conn

	// once makes Close idempotent: the wrapped connection is closed and the
	// callback is invoked a single time, however often Close is called.
	once sync.Once
	// closeErr caches the result of closing the wrapped connection so that
	// every Close call reports the same error.
	closeErr error
	// onDidClose, if non-nil, runs once right after the wrapped connection
	// has been closed.
	onDidClose func()
}

// Close closes the wrapped connection and then notifies onDidClose.
// Only the first call does any work; each call returns the error that closing
// the wrapped connection produced on that first call.
func (c *tunnelConn) Close() error {
	c.once.Do(func() {
		c.closeErr = c.Conn.Close()
		// A tunnelConn constructed without a callback must still be closable
		// instead of panicking on a nil function call.
		if c.onDidClose != nil {
			c.onDidClose()
		}
	})
	return c.closeErr
}
// Observe starts observing the tunneled ports until the context is canceled.
//
// A goroutine sends the complete list of tunnel states once immediately and
// again after every broadcast on the service's condition variable. Each sent
// state carries its own copy of the Clients map, so receivers can read it
// without holding the service lock.
//
// Both returned channels are closed when the goroutine stops. The error
// channel never receives a value in this implementation. A cancellation that
// happens while the goroutine is waiting for the next change is noticed on
// the next broadcast.
func (p *TunneledPortsService) Observe(ctx context.Context) (<-chan []PortTunnelState, <-chan error) {
	var (
		errchan = make(chan error, 1)
		reschan = make(chan []PortTunnelState)
	)

	go func() {
		defer close(errchan)
		defer close(reschan)

		p.cond.L.Lock()
		defer p.cond.L.Unlock()
		for {
			states := make([]PortTunnelState, 0, len(p.tunnels))
			for _, tunnel := range p.tunnels {
				state := tunnel.State
				// The live Clients map is mutated under the lock by
				// EstablishTunnel; hand out a copy so the receiver does not
				// race with those writes.
				state.Clients = make(map[string]uint32, len(tunnel.State.Clients))
				for clientID, targetPort := range tunnel.State.Clients {
					state.Clients[clientID] = targetPort
				}
				states = append(states, state)
			}

			// The send happens while holding the lock. Without the ctx case a
			// receiver that went away after cancellation would leave this
			// goroutine blocked forever and the whole service locked.
			select {
			case reschan <- states:
			case <-ctx.Done():
				return
			}

			p.cond.Wait()
			if ctx.Err() != nil {
				return
			}
		}
	}()

	return reschan, errchan
}
// validate checks that the description holds usable port numbers.
// It returns an error when desc is nil, when LocalPort is outside 1..65535,
// or when TargetPort exceeds 65535. A TargetPort of 0 is accepted.
func (desc *PortTunnelDescription) validate() error {
	switch {
	case desc == nil:
		// Descriptions are passed around as pointers; report a nil one as an
		// ordinary validation error instead of panicking on the field access.
		return xerrors.Errorf("bad port tunnel description: nil")
	case desc.LocalPort == 0 || desc.LocalPort > 0xFFFF:
		return xerrors.Errorf("bad local port: %d", desc.LocalPort)
	case desc.TargetPort > 0xFFFF:
		return xerrors.Errorf("bad target port: %d", desc.TargetPort)
	}
	return nil
}
// Tunnel opens new tunnels.
//
// Every description is validated first; invalid ones are skipped and their
// errors are accumulated, one per line, into the returned error. A valid
// description registers a new tunnel under its LocalPort. If a tunnel is
// already registered there, its description is overwritten unless
// options.SkipIfExists is set. A nil options behaves like the zero
// TunnelOptions.
//
// It returns the local ports that were registered or updated. Observers are
// woken once if at least one tunnel changed. ctx is not used.
func (p *TunneledPortsService) Tunnel(ctx context.Context, options *TunnelOptions, descs ...*PortTunnelDescription) (tunneled []uint32, err error) {
	// Resolve the option up front so a nil options pointer cannot panic
	// halfway through the loop.
	skipIfExists := options != nil && options.SkipIfExists

	p.cond.L.Lock()
	defer p.cond.L.Unlock()

	for _, desc := range descs {
		if descErr := desc.validate(); descErr != nil {
			if err == nil {
				err = descErr
			} else {
				err = xerrors.Errorf("%s\n%s", err, descErr)
			}
			continue
		}

		tunnel, tunnelExists := p.tunnels[desc.LocalPort]
		switch {
		case !tunnelExists:
			tunnel = &PortTunnel{
				State: PortTunnelState{
					Clients: make(map[string]uint32),
				},
				Conns: make(map[string]map[net.Conn]struct{}),
			}
			p.tunnels[desc.LocalPort] = tunnel
		case skipIfExists:
			continue
		}

		tunnel.State.Desc = *desc
		tunneled = append(tunneled, desc.LocalPort)
	}

	// tunneled grows exactly when a tunnel was created or updated.
	if len(tunneled) > 0 {
		p.cond.Broadcast()
	}
	return tunneled, err
}
// CloseTunnel closes tunnels.
//
// The tunnels registered for the given local ports are removed from the
// service and observers are woken if any were removed. Unknown ports are
// ignored. Afterwards every connection of the removed tunnels is closed.
// This happens outside the lock because closing a tunnel connection takes
// the lock again in its close callback.
//
// It returns the ports that were actually removed and the close errors,
// joined one per line. ctx is not used.
func (p *TunneledPortsService) CloseTunnel(ctx context.Context, localPorts ...uint32) (closedPorts []uint32, err error) {
	var removed []*PortTunnel

	p.cond.L.Lock()
	for _, port := range localPorts {
		if t, ok := p.tunnels[port]; ok {
			delete(p.tunnels, port)
			removed = append(removed, t)
			closedPorts = append(closedPorts, port)
		}
	}
	if len(removed) != 0 {
		p.cond.Broadcast()
	}
	p.cond.L.Unlock()

	for _, t := range removed {
		for _, clientConns := range t.Conns {
			for c := range clientConns {
				switch closeErr := c.Close(); {
				case closeErr == nil:
					// nothing to record
				case err == nil:
					err = closeErr
				default:
					err = xerrors.Errorf("%s\n%s", err, closeErr)
				}
			}
		}
	}
	return closedPorts, err
}
// EstablishTunnel actually establishes the tunnel.
//
// It dials localPort on localhost on behalf of clientID and returns the
// connection wrapped so that closing it unregisters it from the tunnel and
// wakes observers. The call fails when no tunnel is registered for localPort,
// when the client already tunnels this port with a different target port, or
// when the dial fails.
//
// ctx bounds the dial only; once connected, canceling ctx does not affect the
// returned connection. The service lock is held for the duration of the call.
func (p *TunneledPortsService) EstablishTunnel(ctx context.Context, clientID string, localPort uint32, targetPort uint32) (net.Conn, error) {
	// The context used to be ignored, so keep accepting a nil one.
	if ctx == nil {
		ctx = context.Background()
	}

	p.cond.L.Lock()
	defer p.cond.L.Unlock()

	tunnel, tunnelExists := p.tunnels[localPort]
	if !tunnelExists {
		return nil, xerrors.Errorf("client '%s': '%d' tunnel does not exist", clientID, localPort)
	}
	if expectedTargetPort, clientExists := tunnel.State.Clients[clientID]; clientExists && expectedTargetPort != targetPort {
		return nil, xerrors.Errorf("client '%s': %d:%d is already tunneling", clientID, localPort, targetPort)
	}

	addr := net.JoinHostPort("localhost", strconv.FormatUint(uint64(localPort), 10))
	// Dial with the caller's context so a canceled request does not keep the
	// service lock held while the connection attempt runs.
	var dialer net.Dialer
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, err
	}

	// result is declared first because the close callback refers to it.
	var result net.Conn
	result = &tunnelConn{
		Conn: conn,
		onDidClose: func() {
			p.cond.L.Lock()
			defer p.cond.L.Unlock()

			// Only touch the bookkeeping while this very tunnel is still
			// registered. After CloseTunnel (possibly followed by a new Tunnel
			// on the same port) the captured tunnel is stale and must be left
			// alone.
			if current, ok := p.tunnels[localPort]; !ok || current != tunnel {
				return
			}
			delete(tunnel.Conns[clientID], result)
			if len(tunnel.Conns[clientID]) == 0 {
				delete(tunnel.State.Clients, clientID)
			}
			p.cond.Broadcast()
		},
	}

	if tunnel.Conns[clientID] == nil {
		tunnel.Conns[clientID] = make(map[net.Conn]struct{})
	}
	tunnel.Conns[clientID][result] = struct{}{}
	tunnel.State.Clients[clientID] = targetPort
	p.cond.Broadcast()
	return result, nil
}
// Snapshot writes a snapshot to w.
//
// Tunnels are listed in ascending local-port order and, within a tunnel,
// clients in ascending client-ID order, so the same state always produces the
// same output. The service is read-locked while writing. Write errors on w
// are ignored.
func (p *TunneledPortsService) Snapshot(w io.Writer) {
	p.mu.RLock()
	defer p.mu.RUnlock()

	localPorts := make([]uint32, 0, len(p.tunnels))
	for localPort := range p.tunnels {
		localPorts = append(localPorts, localPort)
	}
	sort.Slice(localPorts, func(i, j int) bool { return localPorts[i] < localPorts[j] })

	for _, localPort := range localPorts {
		tunnel := p.tunnels[localPort]
		desc := tunnel.State.Desc
		fmt.Fprintf(w, "Local Port: %d\n", desc.LocalPort)
		fmt.Fprintf(w, "Target Port: %d\n", desc.TargetPort)
		visibility := api.TunnelVisiblity_name[int32(desc.Visibility)]
		fmt.Fprintf(w, "Visibility: %s\n", visibility)

		// Map iteration order is random; sort the client IDs so the client
		// section is as stable as the port order above.
		clientIDs := make([]string, 0, len(tunnel.State.Clients))
		for clientID := range tunnel.State.Clients {
			clientIDs = append(clientIDs, clientID)
		}
		sort.Strings(clientIDs)

		for _, clientID := range clientIDs {
			fmt.Fprintf(w, "Client: %s\n", clientID)
			fmt.Fprintf(w, " Remote Port: %d\n", tunnel.State.Clients[clientID])
			fmt.Fprintf(w, " Tunnel Count: %d\n", len(tunnel.Conns[clientID]))
		}
		fmt.Fprintf(w, "\n")
	}
}