diff --git a/src/lib/__tests__/rateLimit.test.ts b/src/lib/__tests__/rateLimit.test.ts index 1894629..3f1aa1c 100644 --- a/src/lib/__tests__/rateLimit.test.ts +++ b/src/lib/__tests__/rateLimit.test.ts @@ -97,13 +97,13 @@ describe("RateLimiter", () => { }); describe("getClientIp", () => { - it("returns the right-most x-forwarded-for IP", () => { + it("returns the left-most x-forwarded-for IP", () => { const req = new Request("http://localhost", { headers: { "x-forwarded-for": "5.6.7.8, 9.10.11.12" } }); - expect(getClientIp(req)).toBe("9.10.11.12"); + expect(getClientIp(req)).toBe("5.6.7.8"); }); it("trims whitespace from the selected x-forwarded-for IP", () => { @@ -112,7 +112,7 @@ describe("getClientIp", () => { "x-forwarded-for": " 5.6.7.8 , 9.10.11.12 " } }); - expect(getClientIp(req)).toBe("9.10.11.12"); + expect(getClientIp(req)).toBe("5.6.7.8"); }); it("accepts IPv6 x-forwarded-for values", () => { @@ -124,13 +124,13 @@ describe("getClientIp", () => { expect(getClientIp(req)).toBe("2001:db8::1"); }); - it("does not trust x-real-ip when x-forwarded-for is absent", () => { + it("trusts x-real-ip when available", () => { const req = new Request("http://localhost", { headers: { "x-real-ip": "1.2.3.4" } }); - expect(getClientIp(req)).toBe("unknown"); + expect(getClientIp(req)).toBe("1.2.3.4"); }); it("returns unknown if neither header is present", () => { @@ -147,21 +147,50 @@ describe("getClientIp", () => { expect(getClientIp(req)).toBe("unknown"); }); - it("returns unknown when the right-most x-forwarded-for token is empty", () => { + it("returns unknown when the left-most x-forwarded-for token is empty", () => { const req = new Request("http://localhost", { headers: { - "x-forwarded-for": "5.6.7.8, " + "x-forwarded-for": " , 5.6.7.8" } }); expect(getClientIp(req)).toBe("unknown"); }); - it("returns unknown when the right-most x-forwarded-for token is invalid", () => { + it("returns unknown when the left-most x-forwarded-for token is invalid", () => { const req = new Request("http://localhost", { headers: { - "x-forwarded-for": "5.6.7.8, not-an-ip" + "x-forwarded-for": "not-an-ip, 5.6.7.8" } }); expect(getClientIp(req)).toBe("unknown"); }); }); + + it("prefers x-real-ip over x-forwarded-for when both are present", () => { + const req = new Request("http://localhost", { + headers: { + "x-real-ip": "1.2.3.4", + "x-forwarded-for": "5.6.7.8, 9.10.11.12" + } + }); + expect(getClientIp(req)).toBe("1.2.3.4"); + }); + + it("ignores x-real-ip if it is an invalid IP and falls back to x-forwarded-for", () => { + const req = new Request("http://localhost", { + headers: { + "x-real-ip": "invalid-ip", + "x-forwarded-for": "5.6.7.8, 9.10.11.12" + } + }); + expect(getClientIp(req)).toBe("5.6.7.8"); + }); + + it("ignores x-real-ip if it is an invalid IP and returns unknown if x-forwarded-for is missing", () => { + const req = new Request("http://localhost", { + headers: { + "x-real-ip": "invalid-ip" + } + }); + expect(getClientIp(req)).toBe("unknown"); + }); diff --git a/src/lib/rateLimit.ts b/src/lib/rateLimit.ts index 2f17cf6..795cab7 100644 --- a/src/lib/rateLimit.ts +++ b/src/lib/rateLimit.ts @@ -67,10 +67,16 @@ function isValidIp(value: string): boolean { } export function getClientIp(request: Request): string { + const realIp = request.headers.get("x-real-ip"); + if (realIp) { + const trimmedRealIp = realIp.trim(); + if (isValidIp(trimmedRealIp)) return trimmedRealIp; + } + const forwardedFor = request.headers.get("x-forwarded-for"); if (!forwardedFor) return "unknown"; - const proxyObservedIp = forwardedFor.split(",").at(-1)?.trim(); + const proxyObservedIp = forwardedFor.split(",")[0]?.trim(); if (proxyObservedIp && isValidIp(proxyObservedIp)) return proxyObservedIp; return "unknown";