2023-02-15 15:39:20 +01:00

98 lines
2.8 KiB
Go

// Copyright (c) 2022 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.
package auth
import (
"context"
"fmt"
"github.com/bufbuild/connect-go"
)
// Interceptor is a connect.Interceptor that carries caller credentials across
// RPC boundaries. On the client side it sends accessToken as a Bearer
// Authorization header; on the server side it extracts a Bearer token or a
// Cookie header from incoming requests and stores it in the request context.
type Interceptor struct {
	// accessToken is attached to outgoing client requests. It is left empty
	// for interceptors built by NewServerInterceptor.
	accessToken string
}
// WrapUnary implements connect.Interceptor for unary RPCs.
//
// Client side: the configured access token is stored in the context and sent
// as a Bearer Authorization header on the outgoing request.
//
// Server side: credentials (a Bearer token, or failing that a Cookie header)
// are extracted from the incoming request and stored in the context passed
// to next. If none are present, the call fails with connect.CodeUnauthenticated
// and next is never invoked.
func (a *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
	return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
		if req.Spec().IsClient {
			ctx = TokenToContext(ctx, NewAccessToken(a.accessToken))
			// Set (not Add): Authorization is a single-valued header. Add would
			// append a duplicate field when one already exists, and the header
			// could then disagree with the token just stored in ctx.
			req.Header().Set(authorizationHeaderKey, bearerPrefix+a.accessToken)
			return next(ctx, req)
		}

		token, err := tokenFromRequest(ctx, req)
		if err != nil {
			return nil, err
		}
		return next(TokenToContext(ctx, token), req)
	}
}
// WrapStreamingClient implements connect.Interceptor for client-side streams.
// It stores the configured access token in the stream's context and sends it
// to the server as a Bearer Authorization header on the stream's request.
func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
	return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
		ctx = TokenToContext(ctx, NewAccessToken(a.accessToken))
		conn := next(ctx, s)
		// Set (not Add): Authorization is single-valued, so replace any existing
		// value instead of emitting a second, conflicting Authorization field.
		conn.RequestHeader().Set(authorizationHeaderKey, bearerPrefix+a.accessToken)
		return conn
	}
}
// WrapStreamingHandler implements connect.Interceptor for server-side streams.
// It resolves the caller's credentials from the stream's request headers and
// hands them to next through the context. If no credentials can be found,
// the stream fails with that error and next is not called.
func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
	handler := func(ctx context.Context, conn connect.StreamingHandlerConn) error {
		tok, tokErr := tokenFromConn(ctx, conn)
		if tokErr != nil {
			return tokErr
		}
		authedCtx := TokenToContext(ctx, tok)
		return next(authedCtx, conn)
	}
	return handler
}
// NewServerInterceptor creates a server-side interceptor that requires every
// incoming request (unary or streaming) to carry credentials. It accepts either
// a Bearer Authorization header or, when that is absent or unusable, a
// non-empty Cookie header. Requests with neither are rejected with
// connect.CodeUnauthenticated.
//
// The extracted credentials are placed in the request context for handlers to
// use. NOTE(review): this interceptor only extracts credentials. Checking that
// they are valid presumably happens downstream; confirm against callers.
func NewServerInterceptor() connect.Interceptor {
	return &Interceptor{}
}
// tokenFromRequest extracts caller credentials from a unary request.
// A Bearer token in the Authorization header wins. Otherwise a non-empty
// Cookie header is used. With neither, a connect.CodeUnauthenticated error
// is returned. ctx is currently unused.
func tokenFromRequest(ctx context.Context, req connect.AnyRequest) (Token, error) {
	hdr := req.Header()

	if bearer, bearerErr := BearerTokenFromHeaders(hdr); bearerErr == nil {
		return NewAccessToken(bearer), nil
	}

	// No usable bearer token: fall back to cookie-based credentials.
	if c := hdr.Get("Cookie"); c != "" {
		return NewCookieToken(c), nil
	}

	return Token{}, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("No access token or cookie credentials available on request."))
}
// tokenFromConn extracts caller credentials from a server-side stream's
// request headers. A Bearer token in the Authorization header is preferred.
// A non-empty Cookie header is the fallback. If neither is present, a
// connect.CodeUnauthenticated error is returned. ctx is currently unused.
func tokenFromConn(ctx context.Context, conn connect.StreamingHandlerConn) (Token, error) {
	reqHeaders := conn.RequestHeader()

	if bearer, parseErr := BearerTokenFromHeaders(reqHeaders); parseErr == nil {
		return NewAccessToken(bearer), nil
	}

	cookie := reqHeaders.Get("Cookie")
	if cookie == "" {
		return Token{}, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("No access token or cookie credentials available on request."))
	}
	return NewCookieToken(cookie), nil
}
// NewClientInterceptor returns a client-side interceptor that sends
// accessToken as a Bearer Authorization header on outgoing unary and
// streaming requests. It also stores the token in the request context.
func NewClientInterceptor(accessToken string) connect.Interceptor {
	ic := &Interceptor{accessToken: accessToken}
	return ic
}