Milan Pavlik d069f76edc
[public-api] Refactor JWT Sign/Verify to be reusable for OIDC - WEB-206 (#17327)
* [public-api] Refactor JWT Sign/Verify to be reusable for OIDC

* fix
2023-04-24 15:14:45 +08:00

102 lines
2.2 KiB
Go

// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.
package jws
import (
"crypto"
"crypto/rsa"
"errors"
"fmt"
"github.com/golang-jwt/jwt/v5"
)
func NewRSA256(keyset KeySet) (*RSA256, error) {
verifier, err := NewRSA256VerifierFromKeySet(keyset)
if err != nil {
return nil, err
}
return &RSA256{
RSA256Signer: &RSA256Signer{
Signing: keyset.Signing,
},
RSA256Verifier: verifier,
}, nil
}
type RSA256 struct {
*RSA256Signer
*RSA256Verifier
}
type RSA256Signer struct {
Signing Key
}
func (s *RSA256Signer) Sign(token *jwt.Token) (string, error) {
if token.Method != jwt.SigningMethodRS256 {
return "", errors.New("invalid signing method, token must use RSA256")
}
token.Header[KeyIDName] = s.Signing.ID
signed, err := token.SignedString(s.Signing.Private)
if err != nil {
return "", fmt.Errorf("failed to sign jwt: %w", err)
}
return signed, nil
}
func NewRSA256VerifierFromKeySet(keyset KeySet) (*RSA256Verifier, error) {
keys := map[string]*rsa.PrivateKey{
keyset.Signing.ID: keyset.Signing.Private,
}
for _, v := range keyset.Validating {
if _, found := keys[v.ID]; found {
return nil, errors.New("duplicate keys when constructing JWT")
}
keys[v.ID] = v.Private
}
return &RSA256Verifier{
Keys: keys,
}, nil
}
type RSA256Verifier struct {
// Keys by Key ID
Keys map[string]*rsa.PrivateKey
}
func (v *RSA256Verifier) Verify(token string, claims jwt.Claims, opts ...jwt.ParserOption) (*jwt.Token, error) {
parsed, err := jwt.ParseWithClaims(token, claims, jwt.Keyfunc(func(t *jwt.Token) (interface{}, error) {
return v.getPublicKeyFromKeyID(t)
}), opts...)
if err != nil {
return nil, fmt.Errorf("failed to parse jwt: %w", err)
}
return parsed, nil
}
func (v *RSA256Verifier) getPublicKeyFromKeyID(token *jwt.Token) (crypto.PublicKey, error) {
keyID, ok := token.Header[KeyIDName].(string)
if !ok {
return nil, fmt.Errorf("missing key id: %w", ErrInvalidKeyID)
}
key, ok := v.Keys[keyID]
if !ok {
return nil, fmt.Errorf("unknown key id %s: %w", keyID, ErrInvalidKeyID)
}
return key.Public(), nil
}