package middleware import ( "net" "net/http" "strconv" "strings" "sync" "time" ) type ipWindowCounter struct { windowStart time.Time lastSeen time.Time count int } type IPRateLimiter struct { mu sync.Mutex limit int window time.Duration ttl time.Duration maxEntries int lastCleanup time.Time entries map[string]*ipWindowCounter } const defaultRateLimiterMaxEntries = 10000 func NewIPRateLimiter(limit int, window time.Duration, ttl time.Duration) *IPRateLimiter { if limit <= 0 { limit = 60 } if window <= 0 { window = time.Minute } if ttl <= 0 { ttl = 10 * time.Minute } return &IPRateLimiter{ limit: limit, window: window, ttl: ttl, maxEntries: defaultRateLimiterMaxEntries, entries: map[string]*ipWindowCounter{}, } } func (m *IPRateLimiter) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !m.allow(r) { w.Header().Set("Retry-After", strconv.Itoa(int(m.window.Seconds()))) http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } func (m *IPRateLimiter) allow(r *http.Request) bool { key := clientKey(r) now := time.Now() m.mu.Lock() defer m.mu.Unlock() if m.lastCleanup.IsZero() || now.Sub(m.lastCleanup) >= m.window { m.cleanupLocked(now) m.lastCleanup = now } entry, exists := m.entries[key] if !exists { if len(m.entries) >= m.maxEntries { m.evictOldestLocked() } m.entries[key] = &ipWindowCounter{ windowStart: now, lastSeen: now, count: 1, } return true } if now.Sub(entry.windowStart) >= m.window { entry.windowStart = now entry.lastSeen = now entry.count = 1 return true } entry.lastSeen = now if entry.count >= m.limit { return false } entry.count++ return true } func (m *IPRateLimiter) cleanupLocked(now time.Time) { for key, entry := range m.entries { if now.Sub(entry.lastSeen) > m.ttl { delete(m.entries, key) } } } func (m *IPRateLimiter) evictOldestLocked() { var oldestKey string var oldestTime time.Time first := true for key, entry := range m.entries { if first || entry.lastSeen.Before(oldestTime) { oldestKey = key oldestTime = entry.lastSeen first = false } } if !first { delete(m.entries, oldestKey) } } func clientKey(r *http.Request) string { host := strings.TrimSpace(r.RemoteAddr) if parsedHost, _, err := net.SplitHostPort(host); err == nil { host = parsedHost } if host == "" { host = "unknown" } return host }