134 lines
2.5 KiB
Go
134 lines
2.5 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// ipWindowCounter is the per-client bookkeeping kept by IPRateLimiter.
type ipWindowCounter struct {
	// windowStart is when the client's current counting window began;
	// lastSeen is the time of the client's most recent request, whether it
	// was admitted or rejected. Idle cleanup and capacity eviction key off
	// lastSeen.
	windowStart, lastSeen time.Time

	// count is the number of requests admitted in the current window.
	count int
}
|
|
|
|
type IPRateLimiter struct {
|
|
mu sync.Mutex
|
|
limit int
|
|
window time.Duration
|
|
ttl time.Duration
|
|
maxEntries int
|
|
lastCleanup time.Time
|
|
entries map[string]*ipWindowCounter
|
|
}
|
|
|
|
const (
	// defaultRateLimiterMaxEntries is the client-table capacity that
	// NewIPRateLimiter assigns to IPRateLimiter.maxEntries.
	defaultRateLimiterMaxEntries = 10000
)
|
|
|
|
func NewIPRateLimiter(limit int, window time.Duration, ttl time.Duration) *IPRateLimiter {
|
|
if limit <= 0 {
|
|
limit = 60
|
|
}
|
|
if window <= 0 {
|
|
window = time.Minute
|
|
}
|
|
if ttl <= 0 {
|
|
ttl = 10 * time.Minute
|
|
}
|
|
return &IPRateLimiter{
|
|
limit: limit,
|
|
window: window,
|
|
ttl: ttl,
|
|
maxEntries: defaultRateLimiterMaxEntries,
|
|
entries: map[string]*ipWindowCounter{},
|
|
}
|
|
}
|
|
|
|
func (m *IPRateLimiter) Handler(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !m.allow(r) {
|
|
w.Header().Set("Retry-After", strconv.Itoa(int(m.window.Seconds())))
|
|
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func (m *IPRateLimiter) allow(r *http.Request) bool {
|
|
key := clientKey(r)
|
|
now := time.Now()
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.lastCleanup.IsZero() || now.Sub(m.lastCleanup) >= m.window {
|
|
m.cleanupLocked(now)
|
|
m.lastCleanup = now
|
|
}
|
|
|
|
entry, exists := m.entries[key]
|
|
if !exists {
|
|
if len(m.entries) >= m.maxEntries {
|
|
m.evictOldestLocked()
|
|
}
|
|
m.entries[key] = &ipWindowCounter{
|
|
windowStart: now,
|
|
lastSeen: now,
|
|
count: 1,
|
|
}
|
|
return true
|
|
}
|
|
|
|
if now.Sub(entry.windowStart) >= m.window {
|
|
entry.windowStart = now
|
|
entry.lastSeen = now
|
|
entry.count = 1
|
|
return true
|
|
}
|
|
|
|
entry.lastSeen = now
|
|
if entry.count >= m.limit {
|
|
return false
|
|
}
|
|
entry.count++
|
|
return true
|
|
}
|
|
|
|
func (m *IPRateLimiter) cleanupLocked(now time.Time) {
|
|
for key, entry := range m.entries {
|
|
if now.Sub(entry.lastSeen) > m.ttl {
|
|
delete(m.entries, key)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *IPRateLimiter) evictOldestLocked() {
|
|
var oldestKey string
|
|
var oldestTime time.Time
|
|
first := true
|
|
for key, entry := range m.entries {
|
|
if first || entry.lastSeen.Before(oldestTime) {
|
|
oldestKey = key
|
|
oldestTime = entry.lastSeen
|
|
first = false
|
|
}
|
|
}
|
|
if !first {
|
|
delete(m.entries, oldestKey)
|
|
}
|
|
}
|
|
|
|
// clientKey returns the key under which r is rate limited: the host part of
// r.RemoteAddr, after trimming surrounding whitespace.
//   - If no host:port split is possible, the trimmed address is used verbatim.
//   - An empty result becomes "unknown", so all such requests share one bucket.
//
// NOTE(review): forwarding headers such as X-Forwarded-For are not consulted.
// Behind a reverse proxy, every request may therefore share the proxy's
// address.
func clientKey(r *http.Request) string {
	addr := strings.TrimSpace(r.RemoteAddr)

	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		host = addr
	}

	if len(host) == 0 {
		return "unknown"
	}
	return host
}
|