2022-12-08 13:05:19 -03:00

112 lines
2.6 KiB
Go

// Copyright (c) 2020 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.
package dropwriter
import (
"io"
"sync"
"time"
)
// Clock abstracts time for the bucket limiter. NewBucketClock accepts one so
// tests can drive the bucket with a controlled time source.
type Clock func() time.Time

// NewBucket creates a new bucket limiter with a realtime clock.
//
// capacity is the maximum number of tokens the bucket holds and
// refillRatePerSec is how many tokens are credited back per second.
func NewBucket(capacity, refillRatePerSec int64) *Bucket {
	return NewBucketClock(capacity, refillRatePerSec, time.Now)
}

// NewBucketClock produces a new bucket limiter with a custom clock. Useful for testing.
//
// The bucket starts full: the first adjustment sets the available tokens to
// capacity.
func NewBucketClock(capacity, refillRatePerSec int64, clock Clock) *Bucket {
	return &Bucket{
		clock:      clock,
		capacity:   capacity,
		refillRate: refillRatePerSec,
	}
}

// Bucket implements a token bucket limiter. All token bookkeeping happens
// under mu, so a *Bucket may be shared between goroutines.
type Bucket struct {
	clock Clock

	// capacity is the total token capacity of this bucket
	capacity int64

	// refillRate holds how many tokens we refill per second
	refillRate int64

	// mu guards availableTokens and lastTick
	mu sync.Mutex

	// availableTokens is the total number of tokens currently available
	availableTokens int64

	// lastTick is the point in time up to which refills have been credited.
	// It only advances by the time that paid for whole tokens (or to "now"
	// once the bucket is full), so fractional refills carry over to the next
	// adjustment instead of being dropped.
	lastTick time.Time
}

// adjustTokens credits the tokens accrued since the last adjustment.
func (b *Bucket) adjustTokens() {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.adjustTokensLocked()
}

// adjustTokensLocked refills the bucket according to the time elapsed since
// lastTick. The caller must hold b.mu.
func (b *Bucket) adjustTokensLocked() {
	now := b.clock()
	if b.lastTick.IsZero() {
		// first adjustment/tick ever - set availableTokens to capacity
		b.availableTokens = b.capacity
		b.lastTick = now
		return
	}

	elapsed := now.Sub(b.lastTick)
	if elapsed <= 0 {
		// The clock did not advance, or went backwards (possible with a
		// custom Clock). Never remove tokens because of that.
		return
	}

	accrued := elapsed.Seconds() * float64(b.refillRate)
	if accrued >= float64(b.capacity-b.availableTokens) {
		// Enough time has passed to fill the bucket. Comparing in float64
		// also keeps a long idle period or a huge rate from overflowing int64.
		b.availableTokens = b.capacity
		b.lastTick = now
		return
	}

	tokens := int64(accrued)
	if tokens <= 0 {
		// Less than one whole token accrued. Keep lastTick where it is so the
		// fraction keeps accumulating; resetting it here would starve the
		// bucket whenever callers arrive more often than once per token.
		return
	}
	b.availableTokens += tokens
	// Advance lastTick only by the time that paid for the credited tokens so
	// the fractional remainder carries over to the next adjustment.
	// refillRate > 0 here: tokens >= 1 requires a positive rate.
	b.lastTick = b.lastTick.Add(time.Duration(float64(tokens) / float64(b.refillRate) * float64(time.Second)))
}

// TakeAvailable attempts to remove req tokens from the bucket. If there are fewer tokens available
// all remaining tokens are removed and returned. A non-positive req takes nothing and returns 0, so
// the result is always in [0, req].
func (b *Bucket) TakeAvailable(req int64) int64 {
	b.mu.Lock()
	defer b.mu.Unlock()

	// Refill and take inside one critical section so concurrent callers
	// cannot interleave between the two steps.
	b.adjustTokensLocked()

	if req <= 0 || b.availableTokens <= 0 {
		return 0
	}
	grant := req
	if grant > b.availableTokens {
		grant = b.availableTokens
	}
	b.availableTokens -= grant
	return grant
}
// writer wraps a destination and lets through, per Write call, only as many
// bytes as the bucket grants; the rest are dropped.
type writer struct {
	w      io.Writer // destination for the bytes the bucket allows through
	bucket *Bucket   // limiter consulted on every Write
}
// Write forwards the prefix of buf permitted by the bucket (one token per
// byte) to the underlying writer and drops the remainder.
//
// On success it reports len(buf) bytes written even when bytes were dropped.
// Dropping is the point of this writer: reporting the short count would make
// callers fail with ErrShortWrite or simply retry. If the underlying writer
// fails, its byte count and error are returned unchanged.
func (w *writer) Write(buf []byte) (int, error) {
	allowed := w.bucket.TakeAvailable(int64(len(buf)))
	written, err := w.w.Write(buf[:allowed])
	if err != nil {
		return written, err
	}
	return len(buf), nil
}
// Writer produces a new rate limited dropping writer.
//
// Each Write forwards to dst only as many bytes as b currently grants (one
// token per byte). Anything beyond that is discarded, and the write still
// reports the full length on success.
func Writer(dst io.Writer, b *Bucket) io.Writer {
	limited := &writer{bucket: b, w: dst}
	return limited
}