mirror of
https://github.com/gitpod-io/gitpod.git
synced 2025-12-08 17:36:30 +00:00
278 lines
7.3 KiB
Go
278 lines
7.3 KiB
Go
// Copyright (c) 2021 Gitpod GmbH. All rights reserved.
|
|
// Licensed under the GNU Affero General Public License (AGPL).
|
|
// See License-AGPL.txt in the project root for license information.
|
|
|
|
package ports
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sort"
|
|
"strconv"
|
|
"sync"
|
|
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/gitpod-io/gitpod/supervisor/api"
|
|
)
|
|
|
|
type PortTunnelDescription struct {
|
|
LocalPort uint32
|
|
TargetPort uint32
|
|
Visibility api.TunnelVisiblity
|
|
}
|
|
|
|
type PortTunnelState struct {
|
|
Desc PortTunnelDescription
|
|
Clients map[string]uint32
|
|
}
|
|
|
|
type PortTunnel struct {
|
|
State PortTunnelState
|
|
Conns map[string]map[net.Conn]struct{}
|
|
}
|
|
|
|
// TunnelOptions controls how Tunnel treats the given descriptions.
type TunnelOptions struct {
	// SkipIfExists leaves a tunnel untouched if one is already registered
	// for the requested local port.
	SkipIfExists bool
}
|
|
|
|
// TunneledPortsInterface observes the tunneled ports.
|
|
type TunneledPortsInterface interface {
|
|
// Observe starts observing the tunneled ports until the context is canceled.
|
|
// The list of tunneled ports is always the complete picture, i.e. if a single port changes,
|
|
// the whole list is returned.
|
|
// When the observer stops operating (because the context as canceled or an irrecoverable
|
|
// error occurred), the observer will close both channels.
|
|
Observe(ctx context.Context) (<-chan []PortTunnelState, <-chan error)
|
|
|
|
// Tunnel notifies clients to install listeners on remote machines.
|
|
// After that such clients should call EstablishTunnel to forward incoming connections.
|
|
Tunnel(ctx context.Context, options *TunnelOptions, descs ...*PortTunnelDescription) ([]uint32, error)
|
|
|
|
// CloseTunnel closes tunnels.
|
|
CloseTunnel(ctx context.Context, localPorts ...uint32) ([]uint32, error)
|
|
|
|
// EstablishTunnel actually establishes the tunnel for an incoming connection on a remote machine.
|
|
EstablishTunnel(ctx context.Context, clientID string, localPort uint32, targetPort uint32) (net.Conn, error)
|
|
}
|
|
|
|
// TunneledPortsService observes the tunneled ports.
|
|
type TunneledPortsService struct {
|
|
mu *sync.RWMutex
|
|
cond *sync.Cond
|
|
tunnels map[uint32]*PortTunnel
|
|
}
|
|
|
|
// NewTunneledPortsService creates a new instance.
|
|
func NewTunneledPortsService(debugEnable bool) *TunneledPortsService {
|
|
var mu sync.RWMutex
|
|
return &TunneledPortsService{
|
|
mu: &mu,
|
|
cond: sync.NewCond(&mu),
|
|
tunnels: make(map[uint32]*PortTunnel),
|
|
}
|
|
}
|
|
|
|
// tunnelConn wraps a net.Conn and runs a callback once the connection got closed.
type tunnelConn struct {
	net.Conn
	// once makes sure that the wrapped connection is closed and the callback
	// is run a single time only, no matter how often Close is called.
	once sync.Once
	// closeErr is the result of closing the wrapped connection.
	closeErr error
	// onDidClose is invoked right after the wrapped connection was closed.
	// It may be nil.
	onDidClose func()
}

// Close closes the wrapped connection and then invokes onDidClose, if set.
//
// Close may be called multiple times: only the first call closes the
// connection and runs the callback, every call returns the error of that
// first close.
func (c *tunnelConn) Close() error {
	c.once.Do(func() {
		c.closeErr = c.Conn.Close()
		// A tunnelConn created without a callback must not panic on close.
		if c.onDidClose != nil {
			c.onDidClose()
		}
	})
	return c.closeErr
}
|
|
|
|
// Observe starts observing the tunneled ports until the context is canceled.
|
|
func (p *TunneledPortsService) Observe(ctx context.Context) (<-chan []PortTunnelState, <-chan error) {
|
|
var (
|
|
errchan = make(chan error, 1)
|
|
reschan = make(chan []PortTunnelState)
|
|
)
|
|
|
|
go func() {
|
|
defer close(errchan)
|
|
defer close(reschan)
|
|
|
|
p.cond.L.Lock()
|
|
defer p.cond.L.Unlock()
|
|
for {
|
|
var i int
|
|
res := make([]PortTunnelState, len(p.tunnels))
|
|
for _, port := range p.tunnels {
|
|
res[i] = port.State
|
|
i++
|
|
}
|
|
reschan <- res
|
|
|
|
p.cond.Wait()
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return reschan, errchan
|
|
}
|
|
|
|
func (desc *PortTunnelDescription) validate() (err error) {
|
|
if desc.LocalPort <= 0 || desc.LocalPort > 0xFFFF {
|
|
return xerrors.Errorf("bad local port: %d", desc.LocalPort)
|
|
}
|
|
if desc.TargetPort > 0xFFFF {
|
|
return xerrors.Errorf("bad target port: %d", desc.TargetPort)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Tunnel opens new tunnels.
|
|
func (p *TunneledPortsService) Tunnel(ctx context.Context, options *TunnelOptions, descs ...*PortTunnelDescription) (tunneled []uint32, err error) {
|
|
var shouldNotify bool
|
|
p.cond.L.Lock()
|
|
defer p.cond.L.Unlock()
|
|
for _, desc := range descs {
|
|
descErr := desc.validate()
|
|
if descErr != nil {
|
|
if err == nil {
|
|
err = descErr
|
|
} else {
|
|
err = xerrors.Errorf("%s\n%s", err, descErr)
|
|
}
|
|
continue
|
|
}
|
|
tunnel, tunnelExists := p.tunnels[desc.LocalPort]
|
|
if !tunnelExists {
|
|
tunnel = &PortTunnel{
|
|
State: PortTunnelState{
|
|
Clients: make(map[string]uint32),
|
|
},
|
|
Conns: make(map[string]map[net.Conn]struct{}),
|
|
}
|
|
p.tunnels[desc.LocalPort] = tunnel
|
|
} else if options.SkipIfExists {
|
|
continue
|
|
}
|
|
tunnel.State.Desc = *desc
|
|
shouldNotify = true
|
|
tunneled = append(tunneled, desc.LocalPort)
|
|
}
|
|
if shouldNotify {
|
|
p.cond.Broadcast()
|
|
}
|
|
return tunneled, err
|
|
}
|
|
|
|
// CloseTunnel closes tunnels.
|
|
func (p *TunneledPortsService) CloseTunnel(ctx context.Context, localPorts ...uint32) (closedPorts []uint32, err error) {
|
|
var closed []*PortTunnel
|
|
p.cond.L.Lock()
|
|
for _, localPort := range localPorts {
|
|
tunnel, existsTunnel := p.tunnels[localPort]
|
|
if !existsTunnel {
|
|
continue
|
|
}
|
|
delete(p.tunnels, localPort)
|
|
closed = append(closed, tunnel)
|
|
closedPorts = append(closedPorts, localPort)
|
|
}
|
|
if len(closed) > 0 {
|
|
p.cond.Broadcast()
|
|
}
|
|
p.cond.L.Unlock()
|
|
for _, tunnel := range closed {
|
|
for _, conns := range tunnel.Conns {
|
|
for conn := range conns {
|
|
closeErr := conn.Close()
|
|
if closeErr == nil {
|
|
continue
|
|
}
|
|
if err == nil {
|
|
err = closeErr
|
|
} else {
|
|
err = xerrors.Errorf("%s\n%s", err, closeErr)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return closedPorts, err
|
|
}
|
|
|
|
// EstablishTunnel actually establishes the tunnel.
|
|
func (p *TunneledPortsService) EstablishTunnel(ctx context.Context, clientID string, localPort uint32, targetPort uint32) (net.Conn, error) {
|
|
p.cond.L.Lock()
|
|
defer p.cond.L.Unlock()
|
|
|
|
tunnel, tunnelExists := p.tunnels[localPort]
|
|
if tunnelExists {
|
|
expectedTargetPort, clientExists := tunnel.State.Clients[clientID]
|
|
if clientExists && expectedTargetPort != targetPort {
|
|
return nil, xerrors.Errorf("client '%s': %d:%d is already tunneling", clientID, localPort, targetPort)
|
|
}
|
|
} else {
|
|
return nil, xerrors.Errorf("client '%s': '%d' tunnel does not exist", clientID, localPort)
|
|
}
|
|
|
|
addr := net.JoinHostPort("localhost", strconv.FormatInt(int64(localPort), 10))
|
|
conn, err := net.Dial("tcp", addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var result net.Conn
|
|
result = &tunnelConn{
|
|
Conn: conn,
|
|
onDidClose: func() {
|
|
p.cond.L.Lock()
|
|
defer p.cond.L.Unlock()
|
|
_, existsTunnel := p.tunnels[localPort]
|
|
if !existsTunnel {
|
|
return
|
|
}
|
|
delete(tunnel.Conns[clientID], result)
|
|
if len(tunnel.Conns[clientID]) == 0 {
|
|
delete(tunnel.State.Clients, clientID)
|
|
}
|
|
p.cond.Broadcast()
|
|
},
|
|
}
|
|
if tunnel.Conns[clientID] == nil {
|
|
tunnel.Conns[clientID] = make(map[net.Conn]struct{})
|
|
}
|
|
tunnel.Conns[clientID][result] = struct{}{}
|
|
tunnel.State.Clients[clientID] = targetPort
|
|
p.cond.Broadcast()
|
|
return result, nil
|
|
}
|
|
|
|
// Snapshot writes a snapshot to w.
|
|
func (p *TunneledPortsService) Snapshot(w io.Writer) {
|
|
p.mu.RLock()
|
|
defer p.mu.RUnlock()
|
|
|
|
localPorts := make([]uint32, 0, len(p.tunnels))
|
|
for k := range p.tunnels {
|
|
localPorts = append(localPorts, k)
|
|
}
|
|
sort.Slice(localPorts, func(i, j int) bool { return localPorts[i] < localPorts[j] })
|
|
|
|
for _, localPort := range localPorts {
|
|
tunnel := p.tunnels[localPort]
|
|
fmt.Fprintf(w, "Local Port: %d\n", tunnel.State.Desc.LocalPort)
|
|
fmt.Fprintf(w, "Target Port: %d\n", tunnel.State.Desc.TargetPort)
|
|
visibilty := api.TunnelVisiblity_name[int32(tunnel.State.Desc.Visibility)]
|
|
fmt.Fprintf(w, "Visibility: %s\n", visibilty)
|
|
for clientID, remotePort := range tunnel.State.Clients {
|
|
fmt.Fprintf(w, "Client: %s\n", clientID)
|
|
fmt.Fprintf(w, " Remote Port: %d\n", remotePort)
|
|
fmt.Fprintf(w, " Tunnel Count: %d\n", len(tunnel.Conns[clientID]))
|
|
}
|
|
fmt.Fprintf(w, "\n")
|
|
}
|
|
}
|