diff --git a/examples/provider/provider_with_rate_limiting.tf b/examples/provider/provider_with_rate_limiting.tf new file mode 100644 index 000000000..0f7343dd3 --- /dev/null +++ b/examples/provider/provider_with_rate_limiting.tf @@ -0,0 +1,36 @@ +# Configure the Auth0 Provider with rate limiting +provider "auth0" { + domain = var.auth0_domain + client_id = var.auth0_client_id + client_secret = var.auth0_client_secret + + # Set the maximum API capacity percentage to use + # This prevents hitting Auth0 rate limits by proactively throttling requests + # when the provider reaches 70% of the available rate limit capacity + max_api_capacity = 70 + + # Enable debug mode to see rate limiting logs + debug = true +} + +# Example resources that will benefit from rate limiting +resource "auth0_client" "my_client" { + name = "My Application" + description = "My Application Description" + app_type = "spa" + callbacks = ["https://example.com/callback"] + allowed_origins = ["https://example.com"] +} + +resource "auth0_user" "users" { + count = 100 # Creating many users will benefit from rate limiting + + connection_name = "Username-Password-Authentication" + email = "user${count.index}@example.com" + password = "passpass$WORD1" + + depends_on = [auth0_client.my_client] +} + +# Environment variable alternative: +# export AUTH0_MAX_API_CAPACITY=70 \ No newline at end of file diff --git a/internal/config/client_retry_test.go b/internal/config/client_retry_test.go index 1cb0158e0..b51cbadd7 100644 --- a/internal/config/client_retry_test.go +++ b/internal/config/client_retry_test.go @@ -33,7 +33,7 @@ func TestCustomClientWithRetries(t *testing.T) { writer.WriteHeader(200) })) - client := customClientWithRetries() + client := customClientWithRetries(100, false) request, err := http.NewRequest(http.MethodGet, testServer.URL, nil) require.NoError(t, err) @@ -67,7 +67,7 @@ func TestCustomClientWithRetries(t *testing.T) { writer.WriteHeader(200) })) - client := customClientWithRetries() + client := customClientWithRetries(100, false) request, err := http.NewRequest(http.MethodGet, testServer.URL, nil) require.NoError(t, err) @@ -93,7 +93,7 @@ func TestCustomClientWithRetries(t *testing.T) { writer.WriteHeader(500) })) - client := customClientWithRetries() + client := customClientWithRetries(100, false) request, err := http.NewRequest(http.MethodGet, testServer.URL, nil) require.NoError(t, err) diff --git a/internal/config/config.go b/internal/config/config.go index 5f735658d..59a9e4dfc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -24,6 +24,9 @@ import ( "github.com/zalando/go-keyring" "github.com/auth0/terraform-provider-auth0/internal/mutex" + "github.com/auth0/terraform-provider-auth0/internal/ratelimit" + "github.com/auth0/terraform-provider-auth0/internal/transport" + "github.com/hashicorp/go-hclog" ) const providerName = "Terraform-Provider-Auth0" // #nosec G101 @@ -68,6 +71,7 @@ type authConfig struct { clientAssertionPrivateKey string clientAssertionSigningAlg string customDomainHeader string + maxAPICapacity int } // ConfigureProvider will configure the *schema.Provider so that @@ -83,6 +87,7 @@ func ConfigureProvider(terraformVersion *string) schema.ConfigureContextFunc { clientAssertionPrivateKey: data.Get("client_assertion_private_key").(string), clientAssertionSigningAlg: data.Get("client_assertion_signing_alg").(string), customDomainHeader: data.Get("custom_domain_header").(string), + maxAPICapacity: data.Get("max_api_capacity").(int), } domain := data.Get("domain").(string) @@ -160,7 +165,7 @@ func ConfigureProvider(terraformVersion *string) schema.ConfigureContextFunc { management.WithUserAgent(userAgent(terraformVersion)), management.WithAuth0ClientEnvEntry(providerName, version), management.WithNoRetries(), - management.WithClient(customClientWithRetries()), + management.WithClient(customClientWithRetries(config.maxAPICapacity, debug)), management.WithCustomDomainHeader(config.customDomainHeader)) if err != nil { @@ -305,13 +310,32 @@ func authenticationOption(cfg authConfig) management.Option { } } -func customClientWithRetries() *http.Client { +func customClientWithRetries(maxAPICapacity int, debug bool) *http.Client { + baseTransport := retryableErrorTransport(http.DefaultTransport) + + // Apply rate limiting if maxAPICapacity is less than 100% + if maxAPICapacity > 0 && maxAPICapacity < 100 { + // Create a simple logger for the transport + logger := hclog.New(&hclog.LoggerOptions{ + Name: "auth0-rate-limiter", + Level: hclog.Info, + }) + if debug { + logger.SetLevel(hclog.Debug) + } + + rateLimitManager, err := ratelimit.NewRateLimitManager(maxAPICapacity) + if err != nil { + // If we can't create the rate limit manager, fall back to basic rate limiting + logger.Error("Failed to create rate limit manager, falling back to basic rate limiting", "error", err) + } else { + logger.Info(fmt.Sprintf("Auth0 provider running with max_api_capacity configuration at %d%%", maxAPICapacity)) + baseTransport = transport.NewGovernedTransport(baseTransport, rateLimitManager, logger) + } + } + client := &http.Client{ - Transport: rateLimitTransport( - retryableErrorTransport( - http.DefaultTransport, - ), - ), + Transport: rateLimitTransport(baseTransport), } return client diff --git a/internal/provider/provider.go b/internal/provider/provider.go index f015940c5..385df0b2e 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -1,6 +1,7 @@ package provider import ( + "fmt" "os" "github.com/auth0/terraform-provider-auth0/internal/auth0/networkacl" @@ -146,6 +147,23 @@ func New() *schema.Provider { Description: "When specified, this header is added to requests targeting a set of pre-defined whitelisted URLs " + "Global setting overrides all resource specific `custom_domain_header` value", }, + "max_api_capacity": { + Type: schema.TypeInt, + Optional: true, + DefaultFunc: schema.EnvDefaultFunc("AUTH0_MAX_API_CAPACITY", 100), + ValidateFunc: func(val interface{}, key string) (warns []string, errs []error) { + v := val.(int) + if v < 1 || v > 100 { + errs = append(errs, fmt.Errorf("%q must be between 1 and 100, got: %d", key, v)) + } + return + }, + Description: "Sets what percentage of capacity the provider can use of the total rate limit " + + "capacity while making calls to the Auth0 management API endpoints. Auth0 API operates with " + + "rate limits per endpoint. See Auth0 Rate Limit Policy: " + + "https://auth0.com/docs/troubleshoot/product-lifecycle/rate-limit-policy. " + + "It can also be sourced from the `AUTH0_MAX_API_CAPACITY` environment variable.", + }, }, ResourcesMap: map[string]*schema.Resource{ "auth0_action": action.NewResource(), diff --git a/internal/ratelimit/ratelimit.go b/internal/ratelimit/ratelimit.go new file mode 100644 index 000000000..b230e8246 --- /dev/null +++ b/internal/ratelimit/ratelimit.go @@ -0,0 +1,248 @@ +package ratelimit + +import ( + "fmt" + "regexp" + "strings" + "sync" + "time" +) + +// RateLimitManager synchronizes keeping account of current known rate limit values +// from Auth0 management endpoints. See: +// https://auth0.com/docs/troubleshoot/product-lifecycle/rate-limit-policy +// +// The Auth0 Terraform Provider can not account for other clients consumption of +// API limits but it can account for its own usage and attempt to preemptively +// react appropriately. +type RateLimitManager struct { + lock sync.Mutex + capacity int + status map[string]*RateLimitStatus + buckets map[string]string +} + +// RateLimitStatus is used to hold rate limit information from Auth0's API +type RateLimitStatus struct { + limit int + remaining int + reset int64 // UTC epoch time in seconds +} + +// NewRateLimitManager returns a new rate limit manager object that represents untilized +// capacity under the specified capacity percentage. +func NewRateLimitManager(capacity int) (*RateLimitManager, error) { + rootStatus := &RateLimitStatus{} + manager := &RateLimitManager{ + capacity: capacity, + status: map[string]*RateLimitStatus{ + "/": rootStatus, + }, + buckets: map[string]string{}, + } + manager.initRateLimitMapping() + + return manager, nil +} + +// HasCapacity approximates if there is capacity below the rate limit manager's maximum +// capacity threshold. +func (m *RateLimitManager) HasCapacity(method, endpoint string) bool { + status := m.get(method, endpoint) + + // if the status hasn't been updated recently assume there is capacity + if status.reset+60 < time.Now().Unix() { + return true + } + + // calculate utilization + utilization := 100.0 * (float32(status.limit-status.remaining) / float32(status.limit)) + + return utilization <= float32(m.capacity) +} + +// Update updates the known status for the given API endpoint. It is synchronous +// and intelligently accounts for new values regardless of parallelism. +func (m *RateLimitManager) Update(method, endpoint string, limit, remaining int, reset int64) { + m.lock.Lock() + defer m.lock.Unlock() + + status := m.get(method, endpoint) + if reset > status.reset { + // reset value greater than current reset implies we are in a new Auth0 API + // window. set/reset values. + status.reset = reset + status.remaining = remaining + status.limit = limit + return + } + + if reset <= (status.reset - 60) { + // these values are from the previous window, ignore + return + } + + if remaining < status.remaining { + status.remaining = remaining + } +} + +// Status Returns the RateLimitStatus for the given method + endpoint combination. +func (m *RateLimitManager) Status(method, endpoint string) *RateLimitStatus { + return m.get(method, endpoint) +} + +// Class Returns the api endpoint class. +func (m *RateLimitManager) Class(method, endpoint string) string { + path := reAuth0ID.ReplaceAllString(endpoint, "ID") + return m.normalizedKey(method, path) +} + +// Bucket Returns the rate limit bucket the api endpoint falls into. +func (m *RateLimitManager) Bucket(method, endpoint string) string { + path := reAuth0ID.ReplaceAllString(endpoint, "ID") + key := m.normalizedKey(method, path) + bucket, ok := m.buckets[key] + if !ok { + return "/" + } + return bucket +} + +func (m *RateLimitManager) normalizedKey(method, endpoint string) string { + return fmt.Sprintf("%s %s", method, endpoint) +} + +// Reset returns the current reset value of the rate limit status object. +func (s *RateLimitStatus) Reset() int64 { + return s.reset +} + +// Limit returns the current limit value of the rate limit status object. +func (s *RateLimitStatus) Limit() int { + return s.limit +} + +// Remaining returns the current remaining value of the rate limit status object. +func (s *RateLimitStatus) Remaining() int { + return s.remaining +} + +// Regex to match Auth0 IDs - includes various formats: +// - auth0|507f1f77bcf86cd799439011 (social connections) +// - YmF12345678901234567890 (client IDs) +// - rol_12345678901234567890 (role IDs) +// - org_12345678901234567890 (organization IDs) +// - con_12345678901234567890 (connection IDs) +var reAuth0ID = regexp.MustCompile(`(?:auth0\|[a-zA-Z0-9]+|[a-zA-Z]{3}_[a-zA-Z0-9]{20,}|[a-zA-Z0-9]{20,})`) + +func (m *RateLimitManager) get(method, endpoint string) *RateLimitStatus { + // The important point here is the replace all is performing this + // transformation for the bucket lookup /api/v2/users/auth0|507f1f77bcf86cd799439011 + // to /api/v2/users/ID . + path := reAuth0ID.ReplaceAllString(endpoint, "ID") + key := m.normalizedKey(method, path) + bucket, ok := m.buckets[key] + if !ok { + return m.status["/"] + } + return m.status[bucket] +} + +func (m *RateLimitManager) initRateLimitMapping() { + // Auth0 Management API endpoints and their rate limit buckets + // Based on https://auth0.com/docs/troubleshoot/product-lifecycle/rate-limit-policy + rateLimitLines := []string{ + // Users endpoints - these are often the most rate-limited + "/api/v2/users GET /api/v2/users", + "/api/v2/users POST /api/v2/users", + "/api/v2/users/ID GET /api/v2/users/{id}", + "/api/v2/users/ID PATCH /api/v2/users/{id}", + "/api/v2/users/ID DELETE /api/v2/users/{id}", + "/api/v2/users/ID/roles GET /api/v2/users", + "/api/v2/users/ID/roles POST /api/v2/users", + "/api/v2/users/ID/roles DELETE /api/v2/users", + "/api/v2/users/ID/permissions GET /api/v2/users", + "/api/v2/users/ID/permissions POST /api/v2/users", + "/api/v2/users/ID/permissions DELETE /api/v2/users", + + // Clients/Applications + "/api/v2/clients GET /api/v2/clients", + "/api/v2/clients POST /api/v2/clients", + "/api/v2/clients/ID GET /api/v2/clients/{id}", + "/api/v2/clients/ID PATCH /api/v2/clients/{id}", + "/api/v2/clients/ID DELETE /api/v2/clients/{id}", + "/api/v2/clients/ID/credentials GET /api/v2/clients", + "/api/v2/clients/ID/credentials POST /api/v2/clients", + "/api/v2/clients/ID/credentials DELETE /api/v2/clients", + + // Connections + "/api/v2/connections GET /api/v2/connections", + "/api/v2/connections POST /api/v2/connections", + "/api/v2/connections/ID GET /api/v2/connections/{id}", + "/api/v2/connections/ID PATCH /api/v2/connections/{id}", + "/api/v2/connections/ID DELETE /api/v2/connections/{id}", + + // Organizations + "/api/v2/organizations GET /api/v2/organizations", + "/api/v2/organizations POST /api/v2/organizations", + "/api/v2/organizations/ID GET /api/v2/organizations/{id}", + "/api/v2/organizations/ID PATCH /api/v2/organizations/{id}", + "/api/v2/organizations/ID DELETE /api/v2/organizations/{id}", + "/api/v2/organizations/ID/members GET /api/v2/organizations", + "/api/v2/organizations/ID/members POST /api/v2/organizations", + "/api/v2/organizations/ID/members DELETE /api/v2/organizations", + + // Roles + "/api/v2/roles GET /api/v2/roles", + "/api/v2/roles POST /api/v2/roles", + "/api/v2/roles/ID GET /api/v2/roles/{id}", + "/api/v2/roles/ID PATCH /api/v2/roles/{id}", + "/api/v2/roles/ID DELETE /api/v2/roles/{id}", + "/api/v2/roles/ID/permissions GET /api/v2/roles", + "/api/v2/roles/ID/permissions POST /api/v2/roles", + "/api/v2/roles/ID/permissions DELETE /api/v2/roles", + + // Resource Servers + "/api/v2/resource-servers GET /api/v2/resource-servers", + "/api/v2/resource-servers POST /api/v2/resource-servers", + "/api/v2/resource-servers/ID GET /api/v2/resource-servers/{id}", + "/api/v2/resource-servers/ID PATCH /api/v2/resource-servers/{id}", + "/api/v2/resource-servers/ID DELETE /api/v2/resource-servers/{id}", + + // Actions + "/api/v2/actions/actions GET /api/v2/actions", + "/api/v2/actions/actions POST /api/v2/actions", + "/api/v2/actions/actions/ID GET /api/v2/actions/{id}", + "/api/v2/actions/actions/ID PATCH /api/v2/actions/{id}", + "/api/v2/actions/actions/ID DELETE /api/v2/actions/{id}", + + // Tenant settings - lower rate limits typically + "/api/v2/tenants/settings GET /api/v2/tenants", + "/api/v2/tenants/settings PATCH /api/v2/tenants", + + // Default bucket for any unmatched endpoints + "/ GET /", + "/ POST /", + "/ PATCH /", + "/ DELETE /", + "/ PUT /", + } + + for _, line := range rateLimitLines { + vals := strings.Fields(line) + if len(vals) < 3 { + continue + } + path := vals[0] + method := vals[1] + bucket := vals[2] + + key := m.normalizedKey(method, path) + m.buckets[key] = bucket + + if _, ok := m.status[bucket]; !ok { + m.status[bucket] = &RateLimitStatus{} + } + } +} \ No newline at end of file diff --git a/internal/ratelimit/ratelimit_test.go b/internal/ratelimit/ratelimit_test.go new file mode 100644 index 000000000..8bcdc9c61 --- /dev/null +++ b/internal/ratelimit/ratelimit_test.go @@ -0,0 +1,192 @@ +package ratelimit + +import ( + "testing" + "time" +) + +func TestNewRateLimitManager(t *testing.T) { + manager, err := NewRateLimitManager(80) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if manager.capacity != 80 { + t.Errorf("Expected capacity 80, got %d", manager.capacity) + } + + // Check that default status exists + if _, ok := manager.status["/"]; !ok { + t.Error("Expected default status to exist") + } +} + +func TestHasCapacity(t *testing.T) { + manager, _ := NewRateLimitManager(50) + + // Test with no previous data (should return true) + if !manager.HasCapacity("GET", "/api/v2/users") { + t.Error("Expected HasCapacity to return true for unused endpoint") + } + + // Test with capacity exceeded (40 remaining out of 100 = 60% utilization, over 50% capacity) + manager.Update("GET", "/api/v2/users", 100, 40, time.Now().Unix()+60) + if manager.HasCapacity("GET", "/api/v2/users") { + t.Error("Expected HasCapacity to return false when over 50% utilization") + } + + // Test with capacity available (60 remaining out of 100 = 40% utilization, under 50% threshold) + // Create new endpoint to avoid interference from previous update + manager.Update("GET", "/api/v2/clients", 100, 60, time.Now().Unix()+60) + if !manager.HasCapacity("GET", "/api/v2/clients") { + t.Error("Expected HasCapacity to return true when under 50% utilization") + } +} + +func TestUpdate(t *testing.T) { + manager, _ := NewRateLimitManager(80) + + // Test initial update + reset := time.Now().Unix() + 60 + manager.Update("GET", "/api/v2/users", 100, 80, reset) + + status := manager.Status("GET", "/api/v2/users") + if status.Limit() != 100 { + t.Errorf("Expected limit 100, got %d", status.Limit()) + } + if status.Remaining() != 80 { + t.Errorf("Expected remaining 80, got %d", status.Remaining()) + } + if status.Reset() != reset { + t.Errorf("Expected reset %d, got %d", reset, status.Reset()) + } + + // Test update with newer reset time (should update all values) + newReset := reset + 60 + manager.Update("GET", "/api/v2/users", 120, 90, newReset) + + status = manager.Status("GET", "/api/v2/users") + if status.Limit() != 120 { + t.Errorf("Expected limit 120, got %d", status.Limit()) + } + if status.Remaining() != 90 { + t.Errorf("Expected remaining 90, got %d", status.Remaining()) + } + + // Test update with same reset time but lower remaining (should update remaining) + manager.Update("GET", "/api/v2/users", 120, 70, newReset) + + status = manager.Status("GET", "/api/v2/users") + if status.Remaining() != 70 { + t.Errorf("Expected remaining 70, got %d", status.Remaining()) + } +} + +func TestClassAndBucket(t *testing.T) { + manager, _ := NewRateLimitManager(80) + + // Test ID replacement + class := manager.Class("GET", "/api/v2/users/auth0|507f1f77bcf86cd799439011") + expected := "GET /api/v2/users/ID" + if class != expected { + t.Errorf("Expected class %q, got %q", expected, class) + } + + // Test bucket mapping + bucket := manager.Bucket("GET", "/api/v2/users") + expected = "/api/v2/users" + if bucket != expected { + t.Errorf("Expected bucket %q, got %q", expected, bucket) + } + + // Test default bucket for unmapped endpoint + bucket = manager.Bucket("GET", "/api/v2/unknown") + expected = "/" + if bucket != expected { + t.Errorf("Expected default bucket %q, got %q", expected, bucket) + } +} + +func TestNormalizedKey(t *testing.T) { + manager, _ := NewRateLimitManager(80) + + key := manager.normalizedKey("GET", "/api/v2/users") + expected := "GET /api/v2/users" + if key != expected { + t.Errorf("Expected key %q, got %q", expected, key) + } +} + +func TestUpdateWithOldReset(t *testing.T) { + manager, _ := NewRateLimitManager(80) + + // Set initial state with current time + currentTime := time.Now().Unix() + manager.Update("GET", "/api/v2/users", 100, 80, currentTime) + + // Try to update with old reset time (should be ignored) + oldResetTime := currentTime - 120 // 2 minutes ago + manager.Update("GET", "/api/v2/users", 120, 70, oldResetTime) + + // Status should remain unchanged + status := manager.Status("GET", "/api/v2/users") + if status.Limit() != 100 { + t.Errorf("Expected limit to remain 100, got %d", status.Limit()) + } + if status.Remaining() != 80 { + t.Errorf("Expected remaining to remain 80, got %d", status.Remaining()) + } +} + +func TestGetWithUnmappedEndpoint(t *testing.T) { + manager, _ := NewRateLimitManager(80) + + // Test with completely unmapped endpoint + status := manager.get("POST", "/unknown/endpoint") + + // Should return the default root status + if status != manager.status["/"] { + t.Error("Expected unmapped endpoint to return root status") + } +} + +func TestInitRateLimitMappingWithInvalidLine(t *testing.T) { + manager := &RateLimitManager{ + capacity: 50, + status: map[string]*RateLimitStatus{ + "/": &RateLimitStatus{}, + }, + buckets: map[string]string{}, + } + + // This should test the continue case in initRateLimitMapping + // We can't easily test this without modifying the rateLimitLines, + // but we can test that the function handles empty bucket creation + manager.initRateLimitMapping() + + // Verify that some standard buckets were created + if _, ok := manager.status["/api/v2/users"]; !ok { + t.Error("Expected /api/v2/users bucket to be created") + } +} + +func TestAuth0IDRegex(t *testing.T) { + testCases := []struct { + input string + expected string + }{ + {"/api/v2/users/auth0|507f1f77bcf86cd799439011", "/api/v2/users/ID"}, + {"/api/v2/clients/YmF12345678901234567890", "/api/v2/clients/ID"}, + {"/api/v2/roles/rol_12345678901234567890", "/api/v2/roles/ID"}, + {"/api/v2/organizations/org_12345678901234567890/members", "/api/v2/organizations/ID/members"}, + {"/api/v2/connections/con_12345678901234567890", "/api/v2/connections/ID"}, + {"/api/v2/users", "/api/v2/users"}, // No ID to replace + } + + for _, tc := range testCases { + result := reAuth0ID.ReplaceAllString(tc.input, "ID") + if result != tc.expected { + t.Errorf("For input %q, expected %q, got %q", tc.input, tc.expected, result) + } + } +} \ No newline at end of file diff --git a/internal/transport/governed_transport.go b/internal/transport/governed_transport.go new file mode 100644 index 000000000..0e7ee6bdf --- /dev/null +++ b/internal/transport/governed_transport.go @@ -0,0 +1,160 @@ +package transport + +import ( + "context" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/hashicorp/go-hclog" + + "github.com/auth0/terraform-provider-auth0/internal/ratelimit" +) + +const ( + X_RATE_LIMIT_LIMIT = "x-ratelimit-limit" + X_RATE_LIMIT_REMAINING = "x-ratelimit-remaining" + X_RATE_LIMIT_RESET = "x-ratelimit-reset" +) + +type GovernedTransport struct { + base http.RoundTripper + rateLimitManager *ratelimit.RateLimitManager + logger hclog.Logger +} + +// NewGovernedTransport returns a governed transport that relies on pre- and post- +// requests from the http round tripper. The pre request consults the rate limit manager +// to determine if sleeping for the Auth0 API rate limit window is called for. +// The post request updates the information it is holding about the current api +// rate limits. +func NewGovernedTransport(base http.RoundTripper, rateLimitManager *ratelimit.RateLimitManager, logger hclog.Logger) *GovernedTransport { + return &GovernedTransport{ + base: base, + rateLimitManager: rateLimitManager, + logger: logger, + } +} + +// RoundTrip returns the final http response after it has managed the api rate +// limit accounting in the pre and post request hooks. +func (t *GovernedTransport) RoundTrip(req *http.Request) (*http.Response, error) { + path := req.URL.Path + if err := t.preRequestHook(req.Context(), req.Method, path); err != nil { + return nil, err + } + + resp, err := t.base.RoundTrip(req) + // always attempt to save rate limit headers + t.postRequestHook(req.Method, path, resp) + if err != nil { + return nil, err + } + + return resp, nil +} + +func (t *GovernedTransport) preRequestHook(ctx context.Context, method, path string) error { + if t.rateLimitManager.HasCapacity(method, path) { + return nil + } + + status := t.rateLimitManager.Status(method, path) + now := time.Now().Unix() + timeToSleep := status.Reset() - now + + // Cap the sleep time to prevent excessive waiting + if timeToSleep > 300 { // 5 minutes max + timeToSleep = 300 + } + if timeToSleep < 1 { + timeToSleep = 1 + } + + line := fmt.Sprintf("Throttling Auth0 API requests; sleeping for %d seconds until rate limit reset (path class %q, bucket %q: %d remaining of %d total); current request \"%s %s\"", + timeToSleep, + t.rateLimitManager.Class(method, path), + t.rateLimitManager.Bucket(method, path), + status.Remaining(), + status.Limit(), + method, + path, + ) + t.logger.Info(line) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.NewTimer(time.Second * time.Duration(timeToSleep)).C: + return nil + } +} + +func (t *GovernedTransport) postRequestHook(method, path string, resp *http.Response) { + if resp == nil { + return + } + + // Auth0 uses X-RateLimit-Reset (Unix timestamp) + // Try multiple header name variations due to inconsistencies in test environments + resetHeader := resp.Header.Get("x-ratelimit-reset") + if resetHeader == "" { + if vals, ok := resp.Header["x-ratelimit-reset"]; ok && len(vals) > 0 { + resetHeader = vals[0] + } else if vals, ok := resp.Header["X-RateLimit-Reset"]; ok && len(vals) > 0 { + resetHeader = vals[0] + } else { + resetHeader = resp.Header.Get("X-RateLimit-Reset") + } + } + + if resetHeader == "" { + t.logger.Debug(fmt.Sprintf("No rate limit reset header found in response for %s %s. Headers: %v", method, path, resp.Header)) + return + } + + reset, err := strconv.ParseInt(resetHeader, 10, 64) + if err != nil { + t.logger.Warn(fmt.Sprintf("%q response header is missing or invalid, skipping postRequestHook: %+v", X_RATE_LIMIT_RESET, err)) + return + } + + limitHeader := resp.Header.Get("x-ratelimit-limit") + if limitHeader == "" { + // Try direct access + if vals, ok := resp.Header["x-ratelimit-limit"]; ok && len(vals) > 0 { + limitHeader = vals[0] + } else if vals, ok := resp.Header["X-RateLimit-Limit"]; ok && len(vals) > 0 { + limitHeader = vals[0] + } else { + limitHeader = resp.Header.Get("X-RateLimit-Limit") + } + } + + limit, err := strconv.Atoi(limitHeader) + if err != nil { + t.logger.Warn(fmt.Sprintf("%q response header is missing or invalid, skipping postRequestHook: %+v", X_RATE_LIMIT_LIMIT, err)) + return + } + + remainingHeader := resp.Header.Get("x-ratelimit-remaining") + if remainingHeader == "" { + // Try direct access + if vals, ok := resp.Header["x-ratelimit-remaining"]; ok && len(vals) > 0 { + remainingHeader = vals[0] + } else if vals, ok := resp.Header["X-RateLimit-Remaining"]; ok && len(vals) > 0 { + remainingHeader = vals[0] + } else { + remainingHeader = resp.Header.Get("X-RateLimit-Remaining") + } + } + + remaining, err := strconv.Atoi(remainingHeader) + if err != nil { + t.logger.Warn(fmt.Sprintf("%q response header is missing or invalid, skipping postRequestHook: %+v", X_RATE_LIMIT_REMAINING, err)) + return + } + + t.rateLimitManager.Update(method, path, limit, remaining, reset) +} \ No newline at end of file diff --git a/internal/transport/governed_transport_test.go b/internal/transport/governed_transport_test.go new file mode 100644 index 000000000..f56a703db --- /dev/null +++ b/internal/transport/governed_transport_test.go @@ -0,0 +1,332 @@ +package transport + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "strconv" + "testing" + "time" + + "github.com/hashicorp/go-hclog" + + "github.com/auth0/terraform-provider-auth0/internal/ratelimit" +) + +type mockRoundTripper struct { + response *http.Response + err error +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return m.response, m.err +} + +func TestGovernedTransport_RoundTrip(t *testing.T) { + // Create a mock rate limit manager + rateLimitManager, _ := ratelimit.NewRateLimitManager(80) + + // Create a logger + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test", + Output: io.Discard, + }) + + // Create mock response with rate limit headers + resetTime := time.Now().Unix() + 60 + mockResponse := &http.Response{ + StatusCode: 200, + Header: http.Header{ + "x-ratelimit-limit": []string{"100"}, + "x-ratelimit-remaining": []string{"95"}, + "x-ratelimit-reset": []string{strconv.FormatInt(resetTime, 10)}, + }, + Body: io.NopCloser(bytes.NewBufferString("{}")), + } + + mockRoundTripper := &mockRoundTripper{response: mockResponse} + + // Create governed transport + transport := NewGovernedTransport(mockRoundTripper, rateLimitManager, logger) + + // Create test request + req, _ := http.NewRequest("GET", "https://example.auth0.com/api/v2/users", nil) + + // Make request + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if resp.StatusCode != 200 { + t.Errorf("Expected status code 200, got %d", resp.StatusCode) + } + + // Check that rate limit status was updated + status := rateLimitManager.Status("GET", "/api/v2/users") + if status.Limit() != 100 { + t.Errorf("Expected limit 100, got %d", status.Limit()) + } + if status.Remaining() != 95 { + t.Errorf("Expected remaining 95, got %d", status.Remaining()) + } +} + +func TestGovernedTransport_PreRequestHook_NoThrottling(t *testing.T) { + // Create rate limit manager with high capacity + rateLimitManager, _ := ratelimit.NewRateLimitManager(90) + + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test", + Output: io.Discard, + }) + + mockRoundTripper := &mockRoundTripper{response: &http.Response{StatusCode: 200}} + transport := NewGovernedTransport(mockRoundTripper, rateLimitManager, logger) + + // Should not throttle when under capacity + err := transport.preRequestHook(context.Background(), "GET", "/api/v2/users") + if err != nil { + t.Errorf("Expected no error, got %v", err) + } +} + +func TestGovernedTransport_PreRequestHook_WithThrottling(t *testing.T) { + // Create rate limit manager with low capacity + rateLimitManager, _ := ratelimit.NewRateLimitManager(10) + + // Update with high utilization (should trigger throttling) + resetTime := time.Now().Unix() + 2 // Short reset time for test + rateLimitManager.Update("GET", "/api/v2/users", 100, 10, resetTime) + + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test", + Output: io.Discard, + }) + + mockRoundTripper := &mockRoundTripper{response: &http.Response{StatusCode: 200}} + transport := NewGovernedTransport(mockRoundTripper, rateLimitManager, logger) + + // Create context with timeout to prevent test from hanging + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + start := time.Now() + err := transport.preRequestHook(ctx, "GET", "/api/v2/users") + duration := time.Since(start) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + // Should have slept for at least 1 second + if duration < time.Second { + t.Errorf("Expected to sleep for at least 1 second, slept for %v", duration) + } +} + +func TestGovernedTransport_PostRequestHook_InvalidHeaders(t *testing.T) { + rateLimitManager, _ := ratelimit.NewRateLimitManager(80) + + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test", + Output: io.Discard, + }) + + mockRoundTripper := &mockRoundTripper{} + transport := NewGovernedTransport(mockRoundTripper, rateLimitManager, logger) + + // Test with response missing headers + resp := &http.Response{ + StatusCode: 200, + Header: http.Header{}, + } + + // Should not panic or error + transport.postRequestHook("GET", "/api/v2/users", resp) + + // Test with nil response + transport.postRequestHook("GET", "/api/v2/users", nil) +} + +func TestGovernedTransport_RoundTrip_WithError(t *testing.T) { + rateLimitManager, _ := ratelimit.NewRateLimitManager(80) + + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test", + Output: io.Discard, + }) + + // Mock transport that returns an error + mockRoundTripper := &mockRoundTripper{response: nil, err: fmt.Errorf("connection failed")} + transport := NewGovernedTransport(mockRoundTripper, rateLimitManager, logger) + + req, _ := http.NewRequest("GET", "https://example.auth0.com/api/v2/users", nil) + + // This should return an error but still call postRequestHook + _, err := transport.RoundTrip(req) + if err == nil { + t.Error("Expected an error from RoundTrip") + } +} + +func TestGovernedTransport_PreRequestHook_WithCancellation(t *testing.T) { + rateLimitManager, _ := ratelimit.NewRateLimitManager(10) + + // Set high utilization to trigger throttling + resetTime := time.Now().Unix() + 10 // 10 seconds in future + rateLimitManager.Update("GET", "/api/v2/users", 100, 5, resetTime) + + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test", + Output: io.Discard, + }) + + mockRoundTripper := &mockRoundTripper{response: &http.Response{StatusCode: 200}} + transport := NewGovernedTransport(mockRoundTripper, rateLimitManager, logger) + + // Create context that is immediately cancelled + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := transport.preRequestHook(ctx, "GET", "/api/v2/users") + if err == nil { + t.Error("Expected context cancellation error") + } +} + +func TestGovernedTransport_PreRequestHook_WithTimeCapLimiting(t *testing.T) { + rateLimitManager, _ := ratelimit.NewRateLimitManager(10) + + // Set high utilization with very long reset time (should be capped) + resetTime := time.Now().Unix() + 500 // 500 seconds in future (should be capped to 300) + rateLimitManager.Update("GET", "/api/v2/users", 100, 5, resetTime) + + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test", + Output: io.Discard, + }) + + mockRoundTripper := &mockRoundTripper{response: &http.Response{StatusCode: 200}} + transport := NewGovernedTransport(mockRoundTripper, rateLimitManager, logger) + + // This should cap the sleep time to 300 seconds, but for testing we'll use a short timeout + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + err := transport.preRequestHook(ctx, "GET", "/api/v2/users") + duration := time.Since(start) + + // Should get context deadline exceeded + if err == nil { + t.Error("Expected context deadline exceeded error") + } + + // Should have tried to sleep but been interrupted + if duration < 50*time.Millisecond { + t.Error("Expected to attempt sleeping before context cancellation") + } +} + +func TestGovernedTransport_PostRequestHook_WithMissingLimitHeader(t *testing.T) { + rateLimitManager, _ := ratelimit.NewRateLimitManager(80) + + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test", + Output: io.Discard, + }) + + mockRoundTripper := &mockRoundTripper{} + transport := NewGovernedTransport(mockRoundTripper, rateLimitManager, logger) + + resetTime := time.Now().Unix() + 60 + + // Response with only reset header, missing limit + resp := &http.Response{ + StatusCode: 200, + Header: http.Header{ + "x-ratelimit-reset": []string{strconv.FormatInt(resetTime, 10)}, + // Missing limit and remaining headers + }, + } + + // Should not panic, should just return early + transport.postRequestHook("GET", "/api/v2/users", resp) + + // Status should remain at defaults since update failed + status := rateLimitManager.Status("GET", "/api/v2/users") + if status.Limit() != 0 { + t.Errorf("Expected limit to remain 0, got %d", status.Limit()) + } +} + +func TestGovernedTransport_PostRequestHook_WithMissingRemainingHeader(t *testing.T) { + rateLimitManager, _ := ratelimit.NewRateLimitManager(80) + + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test", + Output: io.Discard, + }) + + mockRoundTripper := &mockRoundTripper{} + transport := NewGovernedTransport(mockRoundTripper, rateLimitManager, logger) + + resetTime := time.Now().Unix() + 60 + + // Response with reset and limit, missing remaining + resp := &http.Response{ + StatusCode: 200, + Header: http.Header{ + "x-ratelimit-reset": []string{strconv.FormatInt(resetTime, 10)}, + "x-ratelimit-limit": []string{"100"}, + // Missing remaining header + }, + } + + // Should not panic, should just return early + transport.postRequestHook("GET", "/api/v2/users", resp) + + // Status should remain at defaults since update failed + status := rateLimitManager.Status("GET", "/api/v2/users") + if status.Limit() != 0 { + t.Errorf("Expected limit to remain 0, got %d", status.Limit()) + } +} + +func TestGovernedTransport_PostRequestHook_AlternativeHeaders(t *testing.T) { + rateLimitManager, _ := ratelimit.NewRateLimitManager(80) + + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test", + Output: io.Discard, + }) + + mockRoundTripper := &mockRoundTripper{} + transport := NewGovernedTransport(mockRoundTripper, rateLimitManager, logger) + + resetTime := time.Now().Unix() + 60 + + // Test with alternative header names (capital R) + resp := &http.Response{ + StatusCode: 200, + Header: http.Header{ + "X-RateLimit-Limit": []string{"100"}, + "X-RateLimit-Remaining": []string{"85"}, + "X-RateLimit-Reset": []string{strconv.FormatInt(resetTime, 10)}, + }, + } + + transport.postRequestHook("GET", "/api/v2/users", resp) + + // Check that rate limit status was updated + status := rateLimitManager.Status("GET", "/api/v2/users") + if status.Limit() != 100 { + t.Errorf("Expected limit 100, got %d", status.Limit()) + } + if status.Remaining() != 85 { + t.Errorf("Expected remaining 85, got %d", status.Remaining()) + } +} \ No newline at end of file diff --git a/templates/index.md.tmpl b/templates/index.md.tmpl index d171a460a..fd97cb3eb 100644 --- a/templates/index.md.tmpl +++ b/templates/index.md.tmpl @@ -21,6 +21,10 @@ Use the navigation to the left to read about the available resources and data so {{ tffile "examples/provider/provider_with_private_jwt.tf" }} +### Rate Limiting Configuration + +{{ tffile "examples/provider/provider_with_rate_limiting.tf" }} + ~> Hard-coding credentials into any Terraform configuration is not recommended, and risks secret leakage should this file ever be committed to a public version control system. See [Environment Variables](#environment-variables) for a better alternative.