diff --git a/test/src/lib.rs b/test/src/lib.rs index 91b082c..4f28ca7 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -28,6 +28,7 @@ mod kv; mod put_raw; mod queue; mod r2; +mod rate_limit; mod request; mod router; mod secret_store; diff --git a/test/src/rate_limit.rs b/test/src/rate_limit.rs new file mode 100644 index 0000000..d8665d1 --- /dev/null +++ b/test/src/rate_limit.rs @@ -0,0 +1,84 @@ +use super::SomeSharedData; +use std::collections::HashMap; +use worker::{js_sys, Env, Request, Response, Result}; + +#[worker::send] +pub async fn handle_rate_limit_check( + _req: Request, + env: Env, + _data: SomeSharedData, +) -> Result { + let rate_limiter = env.rate_limiter("TEST_RATE_LIMITER")?; + + // Use a fixed key for testing + let outcome = rate_limiter.limit("test-key".to_string()).await?; + + Response::from_json(&serde_json::json!({ + "success": outcome.success, + })) +} + +#[worker::send] +pub async fn handle_rate_limit_with_key( + req: Request, + env: Env, + _data: SomeSharedData, +) -> Result { + let uri = req.url()?; + let segments = uri.path_segments().unwrap().collect::>(); + let key = segments.get(2).unwrap_or(&"default-key"); + + let rate_limiter = env.rate_limiter("TEST_RATE_LIMITER")?; + let outcome = rate_limiter.limit(key.to_string()).await?; + + Response::from_json(&serde_json::json!({ + "success": outcome.success, + "key": key, + })) +} + +#[worker::send] +pub async fn handle_rate_limit_bulk_test( + _req: Request, + env: Env, + _data: SomeSharedData, +) -> Result { + let rate_limiter = env.rate_limiter("TEST_RATE_LIMITER")?; + + // Test multiple requests to verify rate limiting behavior + let mut results = Vec::new(); + for i in 0..15 { + let key = format!("bulk-test-{}", i % 3); // Use 3 different keys + let outcome = rate_limiter.limit(key.clone()).await?; + results.push(serde_json::json!({ + "index": i, + "key": key, + "success": outcome.success, + })); + } + + Response::from_json(&serde_json::json!({ + "results": results, + })) +} + +#[worker::send] +pub async fn handle_rate_limit_reset( + _req: Request, + env: Env, + _data: SomeSharedData, +) -> Result { + let rate_limiter = env.rate_limiter("TEST_RATE_LIMITER")?; + + // Use a unique key to avoid interference with other tests + let key = format!("reset-test-{}", js_sys::Date::now()); + + // Make multiple requests with the same key + let mut outcomes = HashMap::new(); + for i in 0..12 { + let outcome = rate_limiter.limit(key.clone()).await?; + outcomes.insert(format!("request_{}", i + 1), outcome.success); + } + + Response::from_json(&outcomes) +} diff --git a/test/src/router.rs b/test/src/router.rs index ff1d2c2..217af7e 100644 --- a/test/src/router.rs +++ b/test/src/router.rs @@ -1,7 +1,7 @@ use crate::{ alarm, analytics_engine, assets, auto_response, cache, container, counter, d1, durable, fetch, - form, js_snippets, kv, put_raw, queue, r2, request, secret_store, service, socket, sql_counter, - sql_iterator, user, ws, SomeSharedData, GLOBAL_STATE, + form, js_snippets, kv, put_raw, queue, r2, rate_limit, request, secret_store, service, socket, + sql_counter, sql_iterator, user, ws, SomeSharedData, GLOBAL_STATE, }; #[cfg(feature = "http")] use std::convert::TryInto; @@ -227,6 +227,10 @@ macro_rules! add_routes ( add_route!($obj, get, sync, "/test-panic", handle_test_panic); add_route!($obj, post, "/container/echo", container::handle_container); add_route!($obj, get, "/container/ws", container::handle_container); + add_route!($obj, get, "/rate-limit/check", rate_limit::handle_rate_limit_check); + add_route!($obj, get, format_route!("/rate-limit/key/{}", "key"), rate_limit::handle_rate_limit_with_key); + add_route!($obj, get, "/rate-limit/bulk-test", rate_limit::handle_rate_limit_bulk_test); + add_route!($obj, get, "/rate-limit/reset", rate_limit::handle_rate_limit_reset); }); #[cfg(feature = "http")] diff --git a/test/tests/mf.ts b/test/tests/mf.ts index 2f3299a..7a63c7d 100644 --- a/test/tests/mf.ts +++ b/test/tests/mf.ts @@ -112,6 +112,14 @@ const mf_instance = new Miniflare({ HTTP_ANALYTICS: { scriptName: "mini-analytics-engine" // mock out analytics engine binding to the "mini-analytics-engine" worker } + }, + ratelimits: { + TEST_RATE_LIMITER: { + simple: { + limit: 10, + period: 60, + } + } } }, { diff --git a/test/tests/rate_limit.spec.ts b/test/tests/rate_limit.spec.ts new file mode 100644 index 0000000..f86fa00 --- /dev/null +++ b/test/tests/rate_limit.spec.ts @@ -0,0 +1,169 @@ +import { describe, test, expect } from "vitest"; +import { mf, mfUrl } from "./mf"; + +describe("rate limit", () => { + test("basic rate limit check", async () => { + const resp = await mf.dispatchFetch(`${mfUrl}rate-limit/check`); + expect(resp.status).toBe(200); + const data = await resp.json() as { success: boolean }; + expect(data).toHaveProperty("success"); + expect(data.success).toBe(true); + }); + + test("rate limit with custom key", async () => { + const key = "test-key-123"; + const resp = await mf.dispatchFetch(`${mfUrl}rate-limit/key/${key}`); + expect(resp.status).toBe(200); + const data = await resp.json() as { success: boolean; key: string }; + expect(data).toHaveProperty("success"); + expect(data).toHaveProperty("key"); + expect(data.key).toBe(key); + expect(data.success).toBe(true); + }); + + test("different keys have independent limits", async () => { + // Test that different keys have separate rate limits + const key1 = "user-1"; + const key2 = "user-2"; + + const resp1 = await mf.dispatchFetch(`${mfUrl}rate-limit/key/${key1}`); + const resp2 = await mf.dispatchFetch(`${mfUrl}rate-limit/key/${key2}`); + + expect(resp1.status).toBe(200); + expect(resp2.status).toBe(200); + + const data1 = await resp1.json() as { success: boolean; key: string }; + const data2 = await resp2.json() as { success: boolean; key: string }; + + expect(data1.success).toBe(true); + expect(data2.success).toBe(true); + expect(data1.key).toBe(key1); + expect(data2.key).toBe(key2); + }); + + test("bulk rate limit test", async () => { + const resp = await mf.dispatchFetch(`${mfUrl}rate-limit/bulk-test`); + expect(resp.status).toBe(200); + const data = await resp.json() as { results: Array<{ index: number; key: string; success: boolean }> }; + expect(data).toHaveProperty("results"); + expect(Array.isArray(data.results)).toBe(true); + expect(data.results.length).toBe(15); + + // Check that results have the expected structure + data.results.forEach((result, index: number) => { + expect(result).toHaveProperty("index"); + expect(result).toHaveProperty("key"); + expect(result).toHaveProperty("success"); + expect(result.index).toBe(index); + expect(typeof result.success).toBe("boolean"); + }); + + // We're using 3 different keys (bulk-test-0, bulk-test-1, bulk-test-2) + // with a limit of 10 per 60 seconds. Each key is used 5 times (15 requests / 3 keys). + // All requests should succeed since each key stays under the limit of 10. + + // Group results by key + const resultsByKey: Record> = {}; + data.results.forEach((result) => { + if (!resultsByKey[result.key]) { + resultsByKey[result.key] = []; + } + resultsByKey[result.key].push(result); + }); + + // Should have exactly 3 keys + expect(Object.keys(resultsByKey).length).toBe(3); + + // Each key should have 5 requests, all successful (under limit of 10) + Object.entries(resultsByKey).forEach(([key, results]) => { + expect(results.length).toBe(5); + results.forEach((result) => { + expect(result.success).toBe(true); + }); + }); + }); + + test("rate limit reset with unique keys", async () => { + const resp = await mf.dispatchFetch(`${mfUrl}rate-limit/reset`); + expect(resp.status).toBe(200); + const data = await resp.json() as Record; + + // Should have 12 request results + expect(Object.keys(data).length).toBe(12); + + // Check that we have the expected keys + for (let i = 1; i <= 12; i++) { + expect(data).toHaveProperty(`request_${i}`); + expect(typeof data[`request_${i}`]).toBe("boolean"); + } + + // With a limit of 10 per 60 seconds, the first 10 requests MUST succeed + // and requests 11 and 12 MUST fail + for (let i = 1; i <= 10; i++) { + expect(data[`request_${i}`]).toBe(true); + } + + // Requests 11 and 12 must be rate limited + expect(data["request_11"]).toBe(false); + expect(data["request_12"]).toBe(false); + }); + + test("multiple rapid requests with same key", async () => { + // Generate a unique key for this test + const testKey = `rapid-test-${Date.now()}`; + + // Make multiple rapid requests with the same key + const promises = []; + for (let i = 0; i < 5; i++) { + promises.push(mf.dispatchFetch(`${mfUrl}rate-limit/key/${testKey}`)); + } + + const responses = await Promise.all(promises); + + // All responses should be successful (200 status) + responses.forEach(resp => { + expect(resp.status).toBe(200); + }); + + // Parse the responses + const results = await Promise.all(responses.map(r => r.json())) as Array<{ success: boolean; key: string }>; + + // All should have the same key + results.forEach(data => { + expect(data.key).toBe(testKey); + expect(data).toHaveProperty("success"); + }); + + // With limit of 10, all 5 requests should succeed + results.forEach((data) => { + expect(data.success).toBe(true); + }); + }); + + test("sequential requests enforce rate limit", async () => { + // Generate a unique key for this test to avoid interference + const testKey = `sequential-test-${Date.now()}`; + + // Make 15 sequential requests with the same key + // With a limit of 10 per 60 seconds, first 10 should succeed, rest should fail + const results: Array<{ success: boolean; key: string }> = []; + for (let i = 0; i < 15; i++) { + const resp = await mf.dispatchFetch(`${mfUrl}rate-limit/key/${testKey}`); + expect(resp.status).toBe(200); + const data = await resp.json() as { success: boolean; key: string }; + results.push(data); + } + + // Verify first 10 requests succeed + for (let i = 0; i < 10; i++) { + expect(results[i].success).toBe(true); + expect(results[i].key).toBe(testKey); + } + + // Verify requests 11-15 are rate limited + for (let i = 10; i < 15; i++) { + expect(results[i].success).toBe(false); + expect(results[i].key).toBe(testKey); + } + }); +}); diff --git a/test/wrangler.toml b/test/wrangler.toml index 1a1c439..8d0af2c 100644 --- a/test/wrangler.toml +++ b/test/wrangler.toml @@ -84,3 +84,8 @@ secret_name = "secret-name" class_name = "EchoContainer" image = "./container-echo/Dockerfile" max_instances = 1 + +[[ratelimits]] +name = "TEST_RATE_LIMITER" +namespace_id = "1" +simple = { limit = 10, period = 60 } diff --git a/worker/src/env.rs b/worker/src/env.rs index 550e3e5..780c258 100644 --- a/worker/src/env.rs +++ b/worker/src/env.rs @@ -4,6 +4,7 @@ use crate::analytics_engine::AnalyticsEngineDataset; #[cfg(feature = "d1")] use crate::d1::D1Database; use crate::kv::KvStore; +use crate::rate_limit::RateLimiter; use crate::Ai; #[cfg(feature = "queue")] use crate::Queue; @@ -122,6 +123,11 @@ impl Env { pub fn secret_store(&self, binding: &str) -> Result { self.get_binding(binding) } + + /// Access a Rate Limiter by the binding name configured in your wrangler.toml file. + pub fn rate_limiter(&self, binding: &str) -> Result { + self.get_binding(binding) + } } pub trait EnvBinding: Sized + JsCast { diff --git a/worker/src/lib.rs b/worker/src/lib.rs index bfdbb3a..f3252a2 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -187,7 +187,7 @@ pub use crate::kv::{KvError, KvStore}; #[cfg(feature = "queue")] pub use crate::queue::*; pub use crate::r2::*; -pub use crate::rate_limit::RateLimiter; +pub use crate::rate_limit::{RateLimitOutcome, RateLimiter}; pub use crate::request::{FromRequest, Request}; pub use crate::request_init::*; pub use crate::response::{EncodeBody, IntoResponse, Response, ResponseBody, ResponseBuilder}; diff --git a/worker/src/router.rs b/worker/src/router.rs index 60bf865..8a52cfb 100644 --- a/worker/src/router.rs +++ b/worker/src/router.rs @@ -7,6 +7,7 @@ use crate::{ durable::ObjectNamespace, env::{Env, Secret, Var}, http::Method, + rate_limit::RateLimiter, request::Request, response::Response, Bucket, Fetcher, KvStore, Result, @@ -118,6 +119,11 @@ impl RouteContext { pub fn d1(&self, binding: &str) -> Result { self.env.d1(binding) } + + /// Access a Rate Limiter by the binding name configured in your wrangler.toml file. + pub fn rate_limiter(&self, binding: &str) -> Result { + self.env.rate_limiter(binding) + } } impl Router<'_, ()> {