package middleware

import (
	"net"
	"net/http"
	"net/url"
	"strings"
)

// CORSMiddleware adds Cross-Origin Resource Sharing (CORS) response headers
// for requests whose Origin is permitted, and answers CORS preflight requests
// directly.
//
// An origin is permitted when it was configured explicitly, when the wildcard
// "*" was configured, or when its hostname matches the request's Host (see
// isSameHost). Only explicitly configured and same-host origins are told that
// credentials are allowed.
type CORSMiddleware struct {
	// allowedOrigins holds the explicitly configured origins, normalized by
	// normalizeCORSOrigin.
	allowedOrigins map[string]struct{}
	// allowAll is true when "*" was configured.
	allowAll bool
}

// NewCORSMiddleware builds a CORSMiddleware from a list of allowed origins.
//
// Entries are normalized before use: surrounding whitespace and trailing
// slashes are removed and the value is lower-cased, so a configured
// "https://App.example.com/" matches the Origin "https://app.example.com".
// The special entry "*" permits every origin, but credentials are never
// advertised for origins admitted only by the wildcard. Empty entries are
// ignored.
func NewCORSMiddleware(origins []string) *CORSMiddleware {
	m := &CORSMiddleware{allowedOrigins: make(map[string]struct{}, len(origins))}
	for _, origin := range origins {
		origin = normalizeCORSOrigin(origin)
		switch origin {
		case "":
			continue
		case "*":
			m.allowAll = true
		default:
			m.allowedOrigins[origin] = struct{}{}
		}
	}
	return m
}

// Handler wraps next with CORS handling.
//
// Every response gets "Origin" appended to its Vary header, because the CORS
// headers written here depend on the request's Origin. When the Origin is
// permitted, it is reflected in Access-Control-Allow-Origin and the allowed
// methods and headers are set. Access-Control-Allow-Credentials is set only
// for explicitly configured or same-host origins.
//
// CORS preflight requests (OPTIONS carrying both Origin and
// Access-Control-Request-Method) are answered with 204 No Content and are not
// passed to next. All other requests, including plain OPTIONS requests, are
// forwarded to next.
func (m *CORSMiddleware) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		// Add rather than Set so a Vary value from an outer middleware survives,
		// and do it unconditionally so caches never serve a response computed
		// for one Origin (or none) to a different Origin.
		h.Add("Vary", "Origin")

		origin := r.Header.Get("Origin")
		if origin != "" {
			trusted := m.isAllowed(origin) || isSameHost(origin, r.Host)
			if trusted || m.allowAll {
				h.Set("Access-Control-Allow-Origin", origin)
				h.Set("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
				h.Set("Access-Control-Allow-Headers", "Content-Type,Authorization")
				if trusted {
					// Never pair credentials with the wildcard: reflecting an
					// arbitrary origin together with credentials would let any
					// site read authenticated responses.
					h.Set("Access-Control-Allow-Credentials", "true")
				}
			}
		}

		isPreflight := r.Method == http.MethodOptions &&
			origin != "" &&
			r.Header.Get("Access-Control-Request-Method") != ""
		if isPreflight {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// isAllowed reports whether origin, after normalization, was configured
// explicitly. The wildcard is not consulted here.
func (m *CORSMiddleware) isAllowed(origin string) bool {
	_, ok := m.allowedOrigins[normalizeCORSOrigin(origin)]
	return ok
}

// normalizeCORSOrigin trims surrounding whitespace and trailing slashes from an
// origin and lower-cases it, so configured values compare equal to the
// lower-case, slash-free serialization used in the Origin header.
func normalizeCORSOrigin(origin string) string {
	return strings.ToLower(strings.TrimRight(strings.TrimSpace(origin), "/"))
}

// isSameHost reports whether the hostname of origin (for example
// "https://example.com:8443") equals, case-insensitively, the hostname of
// requestHost (a Host header value that may carry a port, for example
// "example.com:8080" or "[::1]:8080").
//
// Scheme and port are not compared, so "http://example.com:3000" matches
// "example.com:8080".
// NOTE(review): because ports are ignored, other services on the same
// hostname are treated as same-host and receive credentialed CORS access —
// confirm this is intended.
func isSameHost(origin string, requestHost string) bool {
	parsed, err := url.Parse(origin)
	if err != nil || parsed.Host == "" {
		return false
	}
	originHost := parsed.Hostname()
	if originHost == "" {
		return false
	}

	reqHost := requestHost
	if host, _, splitErr := net.SplitHostPort(requestHost); splitErr == nil {
		reqHost = host
	} else {
		// No port present. url.URL.Hostname returns IPv6 literals without
		// their brackets, so strip them here too ("[::1]" -> "::1").
		reqHost = strings.TrimSuffix(strings.TrimPrefix(requestHost, "["), "]")
	}
	return strings.EqualFold(originHost, reqHost)
}