2023-02-15 15:39:20 +01:00

49 lines
1.3 KiB
Go

// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.
package origin
import (
"context"
"github.com/bufbuild/connect-go"
)
func NewInterceptor() *Interceptor {
return &Interceptor{}
}
type Interceptor struct{}
func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if req.Spec().IsClient {
req.Header().Add("Origin", FromContext(ctx))
} else {
origin := req.Header().Get("Origin")
ctx = ToContext(ctx, origin)
}
return next(ctx, req)
})
}
func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
conn := next(ctx, s)
conn.RequestHeader().Add("Origin", FromContext(ctx))
return conn
}
}
func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
origin := conn.RequestHeader().Get("Origin")
ctx = ToContext(ctx, origin)
return next(ctx, conn)
}
}