Filip Troníček 624c79f9f7
Respond to /idp/keys with JSON (#17789)
* Set JSON mimetype for `/idp/keys`

* Fix typos

* Test for header presence

* Assert JSON for `/.well-known/openid-configuration` as well
2023-05-31 14:45:05 +08:00

243 lines
5.8 KiB
Go

// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.
package identityprovider
import (
"context"
"crypto/rsa"
"encoding/json"
"fmt"
"math/rand"
"strings"
"sync"
"time"
"github.com/gitpod-io/gitpod/common-go/log"
"github.com/redis/go-redis/v9"
"gopkg.in/square/go-jose.v2"
)
// KeyCache holds the identity provider's signing key and retains the public
// halves of keys so that they keep being listed in the JWKS for as long as
// tokens signed with them may still be in use.
type KeyCache interface {
	// Set rotates the active signing key to current.
	Set(ctx context.Context, current *rsa.PrivateKey) error

	// Signer returns a signer for the active key. It returns a nil signer
	// (and a nil error) when Set has not been called yet.
	Signer(ctx context.Context) (jose.Signer, error)

	// PublicKeys returns the retained public keys as a JSON-encoded
	// jose.JSONWebKeySet. The encoded form is returned, rather than the
	// decoded set, so that implementations backed by persisted JSON can
	// hand their data out as-is.
	PublicKeys(ctx context.Context) ([]byte, error)
}
// inMemoryKey is a public key retained by InMemoryCache, together with the
// identifier it is published under and the moment it was stored.
type inMemoryKey struct {
	ID      string         // identifier published as the JWK "kid"
	Created time.Time      // when Set stored the key
	Key     *rsa.PublicKey // public half of the signing key
}
// NewInMemoryCache returns an empty InMemoryCache whose key map is allocated
// and ready for Set.
func NewInMemoryCache() *InMemoryCache {
	cache := new(InMemoryCache)
	cache.keys = map[string]*inMemoryKey{}
	return cache
}
// InMemoryCache keeps every key it is given in process memory. Create it with
// NewInMemoryCache: Set writes into keys, which must therefore be allocated.
type InMemoryCache struct {
	// mu guards current, currentID and keys.
	mu sync.RWMutex

	// current is the active signing key and currentID the identifier it was
	// registered under.
	current   *rsa.PrivateKey
	currentID string

	// keys holds every public key passed to Set, indexed by identifier.
	// Entries are never removed.
	keys map[string]*inMemoryKey
}
// Set rotates the active signing key to current. The public half of current
// is added to the published key set, where it stays for the lifetime of the
// cache. It returns an error, and leaves the cache untouched, if current is nil.
func (imc *InMemoryCache) Set(ctx context.Context, current *rsa.PrivateKey) error {
	// Reject nil up front: dereferencing current.PublicKey below would
	// otherwise panic after the active key had already been overwritten.
	if current == nil {
		return fmt.Errorf("cannot set a nil signing key")
	}

	// One timestamp for both the identifier and the Created field.
	now := time.Now()
	id := fmt.Sprintf("id%d%d", now.Unix(), rand.Int())

	imc.mu.Lock()
	defer imc.mu.Unlock()

	imc.currentID = id
	imc.current = current
	imc.keys[id] = &inMemoryKey{
		ID:      id,
		Created: now,
		Key:     &current.PublicKey,
	}
	return nil
}
// Signer returns an RS256 signer for the active key, or a nil signer and nil
// error if Set has not been called yet. Signed tokens carry the active key's
// identifier in their "kid" header, which matches the KeyID under which
// PublicKeys publishes that key.
func (imc *InMemoryCache) Signer(ctx context.Context) (jose.Signer, error) {
	// Take a consistent snapshot under the lock: Set may rotate the key
	// concurrently, and reading the fields unguarded is a data race.
	imc.mu.RLock()
	current, currentID := imc.current, imc.currentID
	imc.mu.RUnlock()

	if current == nil {
		return nil, nil
	}
	return jose.NewSigner(jose.SigningKey{
		Algorithm: jose.RS256,
		Key:       current,
	}, &jose.SignerOptions{
		ExtraHeaders: map[jose.HeaderKey]interface{}{
			jose.HeaderKey("kid"): currentID,
		},
	})
}
// PublicKeys returns every public key passed to Set, in no particular order,
// as a JSON-encoded jose.JSONWebKeySet. InMemoryCache never expires keys, so
// the set only grows. An empty cache encodes as {"keys":[]}, not
// {"keys":null}, because a JWK Set's "keys" member must be an array.
func (imc *InMemoryCache) PublicKeys(ctx context.Context) ([]byte, error) {
	imc.mu.RLock()
	defer imc.mu.RUnlock()

	// A non-nil slice makes an empty set marshal "keys" as [] instead of null.
	jwks := jose.JSONWebKeySet{
		Keys: make([]jose.JSONWebKey, 0, len(imc.keys)),
	}
	for _, key := range imc.keys {
		jwks.Keys = append(jwks.Keys, jose.JSONWebKey{
			Key:       key.Key,
			KeyID:     key.ID,
			Algorithm: string(jose.RS256),
		})
	}
	return json.Marshal(jwks)
}
const (
	// redisCacheDefaultTTL is how long a public key persisted to Redis lives
	// unless its expiry is extended.
	redisCacheDefaultTTL = time.Hour

	// redisIDPKeyPrefix namespaces the Redis keys holding IDP public keys;
	// the key identifier is appended to it.
	redisIDPKeyPrefix = "idp:keys:"
)
// NewRedisCache returns a RedisCache that stores public keys through client
// and names new keys with defaultKeyID.
func NewRedisCache(client *redis.Client) *RedisCache {
	rc := &RedisCache{Client: client}
	rc.keyID = defaultKeyID
	return rc
}
// defaultKeyID builds an identifier of the form "id-<unix micros>-<random>".
// The key argument is ignored.
func defaultKeyID(_ *rsa.PrivateKey) string {
	micros := time.Now().UnixMicro()
	nonce := rand.Int()
	return fmt.Sprintf("id-%d-%d", micros, nonce)
}
// RedisCache keeps the active signing key in memory and stores public keys in
// Redis under redisIDPKeyPrefix. PublicKeys combines the in-memory key with
// every key found in Redis.
type RedisCache struct {
	// Client is the Redis client that public keys are written to and read from.
	Client *redis.Client

	// keyID chooses the identifier for a newly set key.
	keyID func(current *rsa.PrivateKey) string

	// mu guards current and currentID.
	mu        sync.RWMutex
	current   *rsa.PrivateKey
	currentID string
}
// PublicKeys implements KeyCache. It returns a JSON-encoded JWK set that
// contains the active in-memory key (if any) followed by every key stored
// under redisIDPKeyPrefix in Redis. Each key identifier is emitted at most
// once: SCAN may report an element more than once, and the active key is
// already taken from memory. Keys that expire between SCAN and GET are skipped
// rather than failing the whole request.
func (rc *RedisCache) PublicKeys(ctx context.Context) ([]byte, error) {
	// Snapshot the active key under the lock; Set may rotate it concurrently.
	rc.mu.RLock()
	current, currentID := rc.current, rc.currentID
	rc.mu.RUnlock()

	var (
		res   = []byte("{\"keys\":[")
		first = true
		// seen holds the key identifiers already written to res.
		seen = make(map[string]struct{})
	)
	if current != nil && currentID != "" {
		fc, err := serializePublicKeyAsJSONWebKey(currentID, &current.PublicKey)
		if err != nil {
			return nil, err
		}
		res = append(res, fc...)
		first = false
		seen[currentID] = struct{}{}
	}

	iter := rc.Client.Scan(ctx, 0, redisIDPKeyPrefix+"*", 0).Iterator()
	for iter.Next(ctx) {
		idx := iter.Val()
		// Compare exact identifiers; a suffix match could wrongly hide a
		// different key whose ID happens to end with the current one.
		id := strings.TrimPrefix(idx, redisIDPKeyPrefix)
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}

		key, err := rc.Client.Get(ctx, idx).Result()
		if err == redis.Nil {
			// The entry expired after SCAN listed it; nothing to publish.
			continue
		}
		if err != nil {
			return nil, err
		}
		if !first {
			res = append(res, ',')
		}
		res = append(res, key...)
		first = false
	}
	if err := iter.Err(); err != nil {
		return nil, err
	}

	res = append(res, "]}"...)
	return res, nil
}
// serializePublicKeyAsJSONWebKey JSON-encodes key as a single RS256 JSON Web
// Key identified by keyID.
func serializePublicKeyAsJSONWebKey(keyID string, key *rsa.PublicKey) ([]byte, error) {
	return json.Marshal(jose.JSONWebKey{
		KeyID:     keyID,
		Algorithm: string(jose.RS256),
		Key:       key,
	})
}
// Set implements KeyCache. It stores the public half of current in Redis and,
// once that write succeeds, makes current the active signing key. One
// identifier is used both for the Redis entry and for the "kid" of tokens
// signed afterwards. It returns an error if current is nil or the Redis write
// fails.
func (rc *RedisCache) Set(ctx context.Context, current *rsa.PrivateKey) error {
	if current == nil {
		return fmt.Errorf("cannot set a nil signing key")
	}

	rc.mu.Lock()
	defer rc.mu.Unlock()

	// Generate the identifier exactly once. defaultKeyID is time- and
	// random-based, so calling it twice would store the key under one ID
	// while tokens are signed with a different one.
	id := rc.keyID(current)
	if err := rc.persistPublicKey(ctx, id, &current.PublicKey); err != nil {
		return err
	}
	rc.currentID = id
	rc.current = current
	return nil
}

// persistPublicKey writes key to Redis as a JSON Web Key under
// redisIDPKeyPrefix+id, with an expiry of redisCacheDefaultTTL.
func (rc *RedisCache) persistPublicKey(ctx context.Context, id string, key *rsa.PublicKey) error {
	publicKeyJSON, err := serializePublicKeyAsJSONWebKey(id, key)
	if err != nil {
		return err
	}
	redisKey := redisIDPKeyPrefix + id
	return rc.Client.Set(ctx, redisKey, string(publicKeyJSON), redisCacheDefaultTTL).Err()
}

// Signer implements KeyCache. It returns a nil signer and nil error if Set has
// not been called yet. Otherwise it first extends the Redis expiry of the
// active public key. If the expiry could not be extended (for instance because
// the entry no longer exists), it writes the key again under the same
// identifier, and returns that write's error if it fails. Signed tokens carry
// the active key's identifier as "kid".
func (rc *RedisCache) Signer(ctx context.Context) (jose.Signer, error) {
	// Snapshot the active key; Set may rotate it concurrently. The lock is
	// not held across the Redis round-trips below.
	rc.mu.RLock()
	current, currentID := rc.current, rc.currentID
	rc.mu.RUnlock()

	if current == nil {
		return nil, nil
	}

	resp := rc.Client.Expire(ctx, redisIDPKeyPrefix+currentID, redisCacheDefaultTTL)
	if err := resp.Err(); err != nil {
		log.WithField("keyID", currentID).WithError(err).Warn("cannot extend cached IDP public key TTL")
	}
	if !resp.Val() {
		log.WithField("keyID", currentID).Warn("cannot extend cached IDP public key TTL - trying to repersist")
		// Re-persist under the same identifier so the "kid" below resolves.
		if err := rc.persistPublicKey(ctx, currentID, &current.PublicKey); err != nil {
			log.WithField("keyID", currentID).WithError(err).Error("cannot repersist public key")
			return nil, err
		}
	}

	return jose.NewSigner(jose.SigningKey{
		Algorithm: jose.RS256,
		Key:       current,
	}, &jose.SignerOptions{
		ExtraHeaders: map[jose.HeaderKey]interface{}{
			jose.HeaderKey("kid"): currentID,
		},
	})
}
// Compile-time checks that both cache implementations satisfy KeyCache.
var (
	_ KeyCache = (*RedisCache)(nil)
	_ KeyCache = (*InMemoryCache)(nil)
)