2022-12-08 13:05:19 -03:00

209 lines
5.5 KiB
Go

// Copyright (c) 2020 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.
package grpc
import (
	"context"
	"errors"
	"strings"
	"time"

	"golang.org/x/time/rate"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"

	"github.com/gitpod-io/gitpod/common-go/util"
	lru "github.com/hashicorp/golang-lru"
	"github.com/prometheus/client_golang/prometheus"
)
// keyFunc extracts the rate-limiting key from an incoming request message,
// e.g. a user ID field, so that limits can be applied per key rather than globally.
type keyFunc func(req interface{}) (string, error)
// RateLimit configures the rate limit for a function
type RateLimit struct {
	// Block makes callers wait for a token instead of failing immediately.
	Block bool `json:"block"`
	// BucketSize is the maximum burst size (token bucket capacity).
	BucketSize uint `json:"bucketSize"`
	// RefillInterval is the time between two token refills.
	RefillInterval util.Duration `json:"refillInterval"`
	// KeyCacheSize bounds the LRU cache of per-key limiters; per-key limiting
	// is only active when both Key and KeyCacheSize are set.
	KeyCacheSize uint `json:"keyCacheSize,omitempty"`
	// Key is a dot-separated path to a string field in the request message
	// whose value selects the per-key limiter.
	Key string `json:"key,omitempty"`
}
// Limiter builds a token-bucket limiter from this configuration: one token is
// refilled every RefillInterval, and at most BucketSize tokens accumulate.
func (r RateLimit) Limiter() *rate.Limiter {
	refill := rate.Every(time.Duration(r.RefillInterval))
	return rate.NewLimiter(refill, int(r.BucketSize))
}
// NewRatelimitingInterceptor creates a new rate limiting interceptor
func NewRatelimitingInterceptor(f map[string]RateLimit) RatelimitingInterceptor {
	calls := prometheus.NewCounterVec(prometheus.CounterOpts{
		Namespace: "grpc",
		Subsystem: "server",
		Name:      "rate_limiter_calls_total",
	}, []string{"grpc_method", "rate_limited"})
	cacheHits := prometheus.NewCounterVec(prometheus.CounterOpts{
		Namespace: "grpc",
		Subsystem: "server",
		Name:      "rate_limiter_cache_hit_total",
	}, []string{"grpc_method"})

	functions := make(map[string]*ratelimitedFunction, len(f))
	for method, cfg := range f {
		var (
			cache   *lru.Cache
			extract keyFunc
		)
		// Per-key limiting is only enabled when both a key path and a cache
		// size are configured; otherwise the global limiter applies.
		if cfg.Key != "" && cfg.KeyCacheSize > 0 {
			// lru.New only fails for non-positive sizes, which the guard above excludes.
			cache, _ = lru.New(int(cfg.KeyCacheSize))
			extract = fieldAccessKey(cfg.Key)
		}
		functions[method] = &ratelimitedFunction{
			RateLimit:           cfg,
			GlobalLimit:         cfg.Limiter(),
			Key:                 extract,
			KeyedLimit:          cache,
			RateLimitedTotal:    calls.WithLabelValues(method, "true"),
			NotRateLimitedTotal: calls.WithLabelValues(method, "false"),
			CacheMissTotal:      cacheHits.WithLabelValues(method),
		}
	}

	return RatelimitingInterceptor{
		functions:  functions,
		collectors: []prometheus.Collector{calls, cacheHits},
	}
}
// fieldAccessKey returns a keyFunc that reads the string field located at the
// dot-separated path key inside an incoming protobuf request message.
func fieldAccessKey(key string) keyFunc {
	// The path never changes per request, so split it once up front.
	path := strings.Split(key, ".")
	return func(req interface{}) (string, error) {
		msg, ok := req.(proto.Message)
		if !ok {
			return "", status.Errorf(codes.Internal, "request was not a protobuf message")
		}
		val, ok := getFieldValue(msg.ProtoReflect(), path)
		if !ok {
			return "", status.Errorf(codes.Internal, "Field %s does not exist in message. This is a rate limiting configuration error.", key)
		}
		return val, nil
	}
}
// getFieldValue walks path through nested message fields of msg and returns
// the string value of the final field. It reports ok=false when the path is
// empty, a segment does not exist, an intermediate field is not a message,
// or the final field is not a string.
func getFieldValue(msg protoreflect.Message, path []string) (val string, ok bool) {
	for len(path) > 0 {
		field := msg.Descriptor().Fields().ByName(protoreflect.Name(path[0]))
		if field == nil {
			return "", false
		}

		if len(path) == 1 {
			// Last segment: only string fields are supported as rate-limit keys.
			if field.Kind() != protoreflect.StringKind {
				return "", false
			}
			return msg.Get(field).String(), true
		}

		// More segments follow, so this field must itself be a message.
		if field.Kind() != protoreflect.MessageKind {
			return "", false
		}
		msg = msg.Get(field).Message()
		path = path[1:]
	}
	return "", false
}
// RatelimitingInterceptor limits how often a gRPC function may be called. If the limit has been
// exceeded, we'll return resource exhausted.
type RatelimitingInterceptor struct {
	// functions maps a full gRPC method name to its rate-limiting state.
	functions map[string]*ratelimitedFunction
	// collectors holds the Prometheus metrics exposed via Describe/Collect.
	collectors []prometheus.Collector
}

// RatelimitingInterceptor doubles as a Prometheus collector for its own metrics.
var _ prometheus.Collector = RatelimitingInterceptor{}
// Describe implements prometheus.Collector by delegating to all wrapped collectors.
func (r RatelimitingInterceptor) Describe(d chan<- *prometheus.Desc) {
	for i := range r.collectors {
		r.collectors[i].Describe(d)
	}
}
// Collect implements prometheus.Collector by delegating to all wrapped collectors.
func (r RatelimitingInterceptor) Collect(m chan<- prometheus.Metric) {
	for i := range r.collectors {
		r.collectors[i].Collect(m)
	}
}
// counter is the minimal increment-only interface this package needs from a
// Prometheus counter.
type counter interface {
	Inc()
}
// ratelimitedFunction is the per-method rate-limiting state.
type ratelimitedFunction struct {
	// RateLimit is the configuration this state was built from.
	RateLimit RateLimit
	// GlobalLimit applies when no per-key limiting is configured (Key == nil).
	GlobalLimit *rate.Limiter
	// Key extracts the per-key limiter key from a request; nil disables keyed limiting.
	Key keyFunc
	// KeyedLimit caches per-key *rate.Limiter values; nil when Key is nil.
	KeyedLimit *lru.Cache
	// RateLimitedTotal counts calls that were rate limited.
	RateLimitedTotal counter
	// NotRateLimitedTotal counts calls that passed the limiter.
	NotRateLimitedTotal counter
	// CacheMissTotal counts keys newly added to the limiter cache.
	CacheMissTotal counter
}
// UnaryInterceptor creates a unary interceptor that implements the rate limiting
// configured for each gRPC method. Methods without a configured limit pass
// through untouched. Depending on RateLimit.Block, over-limit calls either
// wait for a token or fail immediately with codes.ResourceExhausted.
func (r RatelimitingInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		f, ok := r.functions[info.FullMethod]
		if !ok {
			// No rate limit configured for this method.
			return handler(ctx, req)
		}

		// Pick the limiter: the method-global one, or a per-key one derived
		// from a string field of the request message.
		var limit *rate.Limiter
		if f.Key == nil {
			limit = f.GlobalLimit
		} else {
			key, err := f.Key(req)
			if err != nil {
				return nil, err
			}
			fresh := f.RateLimit.Limiter()
			found, _ := f.KeyedLimit.ContainsOrAdd(key, fresh)
			if !found && f.CacheMissTotal != nil {
				f.CacheMissTotal.Inc()
			}
			if v, ok := f.KeyedLimit.Get(key); ok {
				limit = v.(*rate.Limiter)
			} else {
				// The entry was evicted between ContainsOrAdd and Get (possible
				// under concurrent load on a small LRU). Fall back to the limiter
				// we just built instead of panicking on a nil type assertion.
				limit = fresh
			}
		}

		var blocked bool
		defer func() {
			if blocked && f.RateLimitedTotal != nil {
				f.RateLimitedTotal.Inc()
			} else if !blocked && f.NotRateLimitedTotal != nil {
				f.NotRateLimitedTotal.Inc()
			}
		}()

		if f.RateLimit.Block {
			err := limit.Wait(ctx)
			// errors.Is also recognizes a wrapped cancellation, which a plain
			// == comparison would miss.
			if errors.Is(err, context.Canceled) {
				blocked = true
				return nil, err
			}
			if err != nil {
				blocked = true
				return nil, status.Error(codes.ResourceExhausted, err.Error())
			}
		} else if !limit.Allow() {
			blocked = true
			return nil, status.Error(codes.ResourceExhausted, "too many requests")
		}

		return handler(ctx, req)
	}
}