2023-02-15 15:39:20 +01:00

93 lines
2.7 KiB
Go

// Copyright (c) 2022 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.
package proxy
import (
"context"
"net/url"
"testing"
gitpod "github.com/gitpod-io/gitpod/gitpod-protocol"
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
"github.com/golang/mock/gomock"
lru "github.com/hashicorp/golang-lru"
"github.com/stretchr/testify/require"
)
// TestConnectionPool verifies that the pool caches one connection per token and,
// once the LRU capacity of 2 is exceeded, evicts the least recently used entry
// (fooToken) while keeping the two most recent ones (barToken, bazToken).
func TestConnectionPool(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	srv := gitpod.NewMockAPIInterface(ctrl)

	cache, err := lru.New(2)
	require.NoError(t, err)
	pool := &ConnectionPool{
		cache: cache,
		// Every token gets the same mock; the test only cares about cache bookkeeping.
		connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
			return srv, nil
		},
	}

	fooToken := auth.NewAccessToken("foo")
	barToken := auth.NewAccessToken("bar")
	bazToken := auth.NewAccessToken("baz")

	// No origin is attached to ctx, so cache entries are looked up with an empty origin below.
	ctx := context.Background()

	_, err = pool.Get(ctx, fooToken)
	require.NoError(t, err)
	require.Equal(t, 1, pool.cache.Len())

	_, err = pool.Get(ctx, barToken)
	require.NoError(t, err)
	require.Equal(t, 2, pool.cache.Len())

	_, err = pool.Get(ctx, bazToken)
	require.NoError(t, err)
	require.Equal(t, 2, pool.cache.Len(), "must keep only last two connections")
	require.False(t, pool.cache.Contains(pool.cacheKey(fooToken, "")), "least recently used connection must be evicted")
	require.True(t, pool.cache.Contains(pool.cacheKey(barToken, "")))
	require.True(t, pool.cache.Contains(pool.cacheKey(bazToken, "")))
}
// TestConnectionPool_ByDistinctOrigins verifies that the same token requested
// from two different origins yields two separate cache entries, one keyed per
// origin. It also checks that a repeated request from an already-seen origin
// reuses that entry instead of inserting a new one.
func TestConnectionPool_ByDistinctOrigins(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	srv := gitpod.NewMockAPIInterface(ctrl)

	cache, err := lru.New(2)
	require.NoError(t, err)
	pool := &ConnectionPool{
		cache: cache,
		connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
			return srv, nil
		},
	}

	token := auth.NewAccessToken("foo")
	ctxWithOriginA := origin.ToContext(context.Background(), "originA")
	ctxWithOriginB := origin.ToContext(context.Background(), "originB")

	_, err = pool.Get(ctxWithOriginA, token)
	require.NoError(t, err)
	require.Equal(t, 1, pool.cache.Len())

	_, err = pool.Get(ctxWithOriginB, token)
	require.NoError(t, err)
	require.Equal(t, 2, pool.cache.Len())

	// Hitting originA again must reuse its entry. If a new key were inserted, the
	// capacity-2 LRU would evict originB and the Contains check below would fail.
	_, err = pool.Get(ctxWithOriginA, token)
	require.NoError(t, err)
	require.Equal(t, 2, pool.cache.Len())
	require.True(t, pool.cache.Contains(pool.cacheKey(token, "originA")))
	require.True(t, pool.cache.Contains(pool.cacheKey(token, "originB")))
}
// TestEndpointBasedOnToken checks the endpoint chosen for each token kind:
// access tokens map to the /v1 path and cookie tokens map to the /gitpod path,
// both on the same server URL.
func TestEndpointBasedOnToken(t *testing.T) {
	serverURL, err := url.Parse("wss://server:3000")
	require.NoError(t, err)

	cases := []struct {
		token    auth.Token
		expected string
	}{
		{token: auth.NewAccessToken("foo"), expected: "wss://server:3000/v1"},
		{token: auth.NewCookieToken("foo"), expected: "wss://server:3000/gitpod"},
	}

	for _, tc := range cases {
		endpoint, endpointErr := getEndpointBasedOnToken(tc.token, serverURL)
		require.NoError(t, endpointErr)
		require.Equal(t, tc.expected, endpoint)
	}
}