Files
go-core/middleware/rate_limit.go
2026-03-01 03:04:10 +01:00

134 lines
2.5 KiB
Go

package middleware
import (
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
)
// ipWindowCounter is the per-client bookkeeping for one fixed rate-limit window.
type ipWindowCounter struct {
	// windowStart is when the client's current window opened; lastSeen is the
	// time of its most recent request, whether admitted or rejected.
	windowStart, lastSeen time.Time
	// count is the number of requests admitted in the current window.
	count int
}
// IPRateLimiter enforces a fixed-window request limit per client key, where
// the key is the remote address as reduced by clientKey. Build one with
// NewIPRateLimiter: the zero value has a nil entries map and cannot record
// clients. The mutable state (lastCleanup, entries) is accessed under mu.
type IPRateLimiter struct {
	mu sync.Mutex

	// Settings populated by NewIPRateLimiter.
	limit       int           // requests admitted per window
	window, ttl time.Duration // window length; idle time before an entry may be swept
	maxEntries  int           // tracked clients before the least recently seen is evicted

	// Bookkeeping guarded by mu.
	lastCleanup time.Time                   // when cleanupLocked last ran
	entries     map[string]*ipWindowCounter // per-client window counters
}
const (
	// defaultRateLimiterMaxEntries is how many distinct client keys a limiter
	// created by NewIPRateLimiter tracks before evicting the least recently
	// seen one to admit a new client.
	defaultRateLimiterMaxEntries = 10000
)
// NewIPRateLimiter builds a limiter that admits at most limit requests per
// client key in each fixed window, drops idle clients after ttl during
// periodic sweeps, and tracks at most defaultRateLimiterMaxEntries clients.
//
// Any non-positive argument falls back to its default: 60 requests, a
// one-minute window, and a ten-minute ttl respectively.
func NewIPRateLimiter(limit int, window time.Duration, ttl time.Duration) *IPRateLimiter {
	lim := &IPRateLimiter{
		limit:      60,
		window:     time.Minute,
		ttl:        10 * time.Minute,
		maxEntries: defaultRateLimiterMaxEntries,
		entries:    make(map[string]*ipWindowCounter),
	}
	if limit > 0 {
		lim.limit = limit
	}
	if window > 0 {
		lim.window = window
	}
	if ttl > 0 {
		lim.ttl = ttl
	}
	return lim
}
// Handler returns middleware that checks every request against the limiter.
// A request over the limit receives 429 "rate limit exceeded" with a
// Retry-After header and is not passed to next; all others go to next.
func (m *IPRateLimiter) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !m.allow(r) {
			// The rejected client's window opened no later than now, so
			// waiting one full window (rounded up) is enough for it to reset.
			w.Header().Set("Retry-After", retryAfterHeaderValue(m.window))
			http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// retryAfterHeaderValue renders window as a Retry-After delay in whole
// seconds. It rounds up, so a client is never told to retry before the window
// can have elapsed, and never returns less than "1": truncation would turn a
// sub-second window into "0", i.e. "retry immediately". Integer Duration
// arithmetic avoids float rounding and overflow.
func retryAfterHeaderValue(window time.Duration) string {
	secs := int64(window / time.Second)
	if window%time.Second != 0 {
		secs++
	}
	if secs < 1 {
		secs = 1
	}
	return strconv.FormatInt(secs, 10)
}
// allow records one request from the client behind r and reports whether it
// fits in that client's current fixed window.
//
// Under m.mu it first sweeps stale entries: on the very first call, and then
// whenever at least one window has passed since the last sweep. An unknown
// client starts a fresh window, evicting the least recently seen entry first
// if the table is already at maxEntries. A known client whose window has run
// out also starts over. Otherwise the request is admitted only while the
// window's count is below limit.
func (m *IPRateLimiter) allow(r *http.Request) bool {
	ip := clientKey(r)
	now := time.Now()

	m.mu.Lock()
	defer m.mu.Unlock()

	sweepDue := m.lastCleanup.IsZero() || now.Sub(m.lastCleanup) >= m.window
	if sweepDue {
		m.cleanupLocked(now)
		m.lastCleanup = now
	}

	counter, found := m.entries[ip]
	switch {
	case !found:
		if len(m.entries) >= m.maxEntries {
			m.evictOldestLocked()
		}
		m.entries[ip] = &ipWindowCounter{windowStart: now, lastSeen: now, count: 1}
		return true
	case now.Sub(counter.windowStart) >= m.window:
		// The previous window has elapsed; this request opens a new one.
		counter.windowStart, counter.lastSeen, counter.count = now, now, 1
		return true
	}

	// lastSeen is refreshed even for requests about to be rejected, so a
	// client that keeps retrying does not look idle to cleanupLocked or
	// evictOldestLocked.
	counter.lastSeen = now
	if counter.count < m.limit {
		counter.count++
		return true
	}
	return false
}
// cleanupLocked deletes every entry that has been idle for longer than m.ttl
// AND whose fixed window has already elapsed as of now. The caller must hold
// m.mu.
//
// An entry whose window is still open is kept even when idle past ttl.
// Deleting it would reset the client's count mid-window, letting the client
// exceed the limit whenever ttl < window. For example, a one-hour window
// combined with NewIPRateLimiter's ten-minute default ttl. Once the window
// has elapsed the entry carries no enforcement state, because allow restarts
// the count on the next request, so dropping it is harmless.
func (m *IPRateLimiter) cleanupLocked(now time.Time) {
	for key, entry := range m.entries {
		idle := now.Sub(entry.lastSeen) > m.ttl
		windowOver := now.Sub(entry.windowStart) >= m.window
		if idle && windowOver {
			// Deleting during range is permitted by the Go spec.
			delete(m.entries, key)
		}
	}
}
// evictOldestLocked removes the single entry with the earliest lastSeen
// timestamp, making room for a new client. It does nothing when the table is
// empty. Among entries with an identical lastSeen, the one reached first in
// (unspecified) map iteration order is removed. allow calls it while holding
// m.mu.
func (m *IPRateLimiter) evictOldestLocked() {
	var (
		victim     string
		victimSeen time.Time
		found      bool
	)
	for ip, counter := range m.entries {
		if found && !counter.lastSeen.Before(victimSeen) {
			continue
		}
		victim, victimSeen, found = ip, counter.lastSeen, true
	}
	if found {
		delete(m.entries, victim)
	}
}
// clientKey returns the rate-limit bucket key for r, derived from
// r.RemoteAddr by hostKeyFromAddr.
func clientKey(r *http.Request) string {
	return hostKeyFromAddr(r.RemoteAddr)
}

// hostKeyFromAddr turns a RemoteAddr-style string into a rate-limit key.
//
// Surrounding whitespace and any port are dropped. A bare bracketed IPv6
// literal such as "[::1]" loses its brackets, so it matches the key produced
// for "[::1]:port". Anything that parses as an IP address is rewritten in
// canonical text form, so equivalent spellings ("2001:0db8::0001" vs
// "2001:db8::1", or an IPv4-mapped IPv6 address vs plain IPv4) share one
// counter instead of each getting a fresh limit. Values that are not IP
// literals, including zoned addresses like "fe80::1%eth0", are kept verbatim.
// An empty result maps to "unknown".
func hostKeyFromAddr(addr string) string {
	host := strings.TrimSpace(addr)
	if h, _, err := net.SplitHostPort(host); err == nil {
		host = h
	} else if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	if ip := net.ParseIP(host); ip != nil {
		host = ip.String()
	}
	if host == "" {
		host = "unknown"
	}
	return host
}