gitpod/components/openvsx-proxy/pkg/openvsxproxy_test.go
2022-12-08 13:05:19 -03:00

147 lines
4.3 KiB
Go

// Copyright (c) 2020 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.
package pkg
import (
"bytes"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"testing"
)
func createFrontend(backendURL string, isDisabledCache bool) (*httptest.Server, *OpenVSXProxy) {
u, _ := url.Parse(backendURL)
cfg := &Config{
URLUpstream: backendURL,
}
if !isDisabledCache {
cfg.AllowCacheDomain = []string{u.Host}
}
openVSXProxy := &OpenVSXProxy{Config: cfg}
openVSXProxy.Setup()
proxy := httputil.NewSingleHostReverseProxy(openVSXProxy.defaultUpstreamURL)
proxy.ModifyResponse = openVSXProxy.ModifyResponse
handler := http.HandlerFunc(openVSXProxy.Handler(proxy))
frontend := httptest.NewServer(handler)
cfg.URLLocal = frontend.URL
return frontend, openVSXProxy
}
func TestAddResponseToCache(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
bodyBytes, _ := io.ReadAll(r.Body)
rw.Header().Set("Content-Type", "application/json")
rw.Write([]byte(fmt.Sprintf("Hello %s!", string(bodyBytes))))
}))
defer backend.Close()
frontend, openVSXProxy := createFrontend(backend.URL, false)
defer frontend.Close()
frontendClient := frontend.Client()
requestBody := backend.URL
req, _ := http.NewRequest("POST", frontend.URL, bytes.NewBuffer([]byte(requestBody)))
req.Close = true
_, err := frontendClient.Do(req)
if err != nil {
t.Fatal(err)
}
key := fmt.Sprintf("POST / %d %s", len(requestBody), openVSXProxy.hash([]byte(requestBody)))
if _, err = openVSXProxy.cacheManager.Get(key); err != nil {
t.Error(err)
}
if _, ok, err := openVSXProxy.ReadCache(key); ok == false || err != nil {
t.Errorf("key not found or error: %v", err)
}
}
func TestServeFromCacheOnUpstreamError(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusInternalServerError)
}))
defer backend.Close()
frontend, openVSXProxy := createFrontend(backend.URL, false)
defer frontend.Close()
requestBody := "Request Body Foo"
key := fmt.Sprintf("POST / %d %s", len(requestBody), openVSXProxy.hash([]byte(requestBody)))
expectedHeader := make(map[string][]string)
expectedHeader["X-Test"] = []string{"Foo Bar"}
expectedResponse := "Response Body Baz"
expectedStatus := 200
openVSXProxy.StoreCache(key, &CacheObject{
Header: expectedHeader,
Body: []byte(expectedResponse),
StatusCode: expectedStatus,
})
frontendClient := frontend.Client()
req, _ := http.NewRequest("POST", frontend.URL, bytes.NewBuffer([]byte(requestBody)))
req.Close = true
res, err := frontendClient.Do(req)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != expectedStatus {
t.Errorf("got status %d; expected %d", res.StatusCode, expectedStatus)
}
if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expectedResponse {
t.Errorf("got body '%s'; expected '%s'", string(bodyBytes), expectedResponse)
}
if h := res.Header.Get("X-Test"); h != "Foo Bar" {
t.Errorf("got header '%s'; expected '%s'", h, "Foo Bar")
}
}
func TestServeFromNonCacheOnUpstreamError(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusInternalServerError)
}))
defer backend.Close()
frontend, openVSXProxy := createFrontend(backend.URL, true)
defer frontend.Close()
requestBody := "Request Body Foo"
key := fmt.Sprintf("POST / %d %s", len(requestBody), openVSXProxy.hash([]byte(requestBody)))
expectedHeader := make(map[string][]string)
expectedResponse := ""
expectedStatus := 500
openVSXProxy.StoreCache(key, &CacheObject{
Header: expectedHeader,
Body: []byte(expectedResponse),
StatusCode: expectedStatus,
})
frontendClient := frontend.Client()
req, _ := http.NewRequest("POST", frontend.URL, bytes.NewBuffer([]byte(requestBody)))
req.Close = true
res, err := frontendClient.Do(req)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != expectedStatus {
t.Errorf("got status %d; expected %d", res.StatusCode, expectedStatus)
}
if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expectedResponse {
t.Errorf("got body '%s'; expected '%s'", string(bodyBytes), expectedResponse)
}
}