add core lib
This commit is contained in:
62
.drone.yml
Normal file
62
.drone.yml
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
---
|
||||||
|
kind: pipeline
|
||||||
|
type: docker
|
||||||
|
name: go-core-ci
|
||||||
|
|
||||||
|
trigger:
|
||||||
|
branch:
|
||||||
|
- main
|
||||||
|
- develop
|
||||||
|
event:
|
||||||
|
- push
|
||||||
|
- pull_request
|
||||||
|
|
||||||
|
environment:
|
||||||
|
GOPROXY: https://nexus.beatrice.wtf/repository/go-group/,direct
|
||||||
|
GOPRIVATE: git.beatrice.wtf/panic.haus/*
|
||||||
|
GONOSUMDB: git.beatrice.wtf/panic.haus/*
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: test
|
||||||
|
image: golang:1.26
|
||||||
|
commands:
|
||||||
|
- go mod download
|
||||||
|
- go test ./... -count=1
|
||||||
|
|
||||||
|
- name: build
|
||||||
|
image: golang:1.26
|
||||||
|
commands:
|
||||||
|
- go build ./...
|
||||||
|
|
||||||
|
---
|
||||||
|
kind: pipeline
|
||||||
|
type: docker
|
||||||
|
name: go-core-release
|
||||||
|
|
||||||
|
trigger:
|
||||||
|
event:
|
||||||
|
- tag
|
||||||
|
ref:
|
||||||
|
- refs/tags/v*
|
||||||
|
|
||||||
|
environment:
|
||||||
|
GOPROXY: https://nexus.beatrice.wtf/repository/go-group/,direct
|
||||||
|
GOPRIVATE: git.beatrice.wtf/panic.haus/*
|
||||||
|
GONOSUMDB: git.beatrice.wtf/panic.haus/*
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: test
|
||||||
|
image: golang:1.26
|
||||||
|
commands:
|
||||||
|
- go mod download
|
||||||
|
- go test ./... -count=1
|
||||||
|
|
||||||
|
- name: build
|
||||||
|
image: golang:1.26
|
||||||
|
commands:
|
||||||
|
- go build ./...
|
||||||
|
|
||||||
|
- name: warm-proxy
|
||||||
|
image: golang:1.26
|
||||||
|
commands:
|
||||||
|
- go list -m git.beatrice.wtf/panic.haus/go-core@${DRONE_TAG}
|
||||||
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
coverage.out
|
||||||
|
bin/
|
||||||
12
README.md
Normal file
12
README.md
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# go-core
|
||||||
|
|
||||||
|
Reusable backend infrastructure module.
|
||||||
|
|
||||||
|
## Packages
|
||||||
|
|
||||||
|
- `dotenv`: load dotenv candidates while preserving existing env vars.
|
||||||
|
- `dbpool`: PostgreSQL pool bootstrap for pgx.
|
||||||
|
- `migrate`: migration helpers around `golang-migrate`.
|
||||||
|
- `smtp`: SMTP mailer implementation.
|
||||||
|
- `middleware`: reusable HTTP middleware (`CORS`, IP rate limiter).
|
||||||
|
- `worker`: generic batch poller/runner utilities.
|
||||||
68
dbpool/postgres.go
Normal file
68
dbpool/postgres.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package dbpool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PoolConfig struct {
|
||||||
|
MaxOpenConns int
|
||||||
|
MinIdleConns int
|
||||||
|
MaxConnLifetime time.Duration
|
||||||
|
MaxConnIdleTime time.Duration
|
||||||
|
HealthCheckPeriod time.Duration
|
||||||
|
ConnectionAcquireWait time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresPool(ctx context.Context, databaseURL string, pool PoolConfig) (*pgxpool.Pool, error) {
|
||||||
|
cfg, err := pgxpool.ParseConfig(databaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse postgres pool config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pool.MaxOpenConns <= 0 {
|
||||||
|
pool.MaxOpenConns = 25
|
||||||
|
}
|
||||||
|
if pool.MinIdleConns < 0 {
|
||||||
|
pool.MinIdleConns = 0
|
||||||
|
}
|
||||||
|
if pool.MinIdleConns > pool.MaxOpenConns {
|
||||||
|
pool.MinIdleConns = pool.MaxOpenConns
|
||||||
|
}
|
||||||
|
if pool.MaxConnLifetime <= 0 {
|
||||||
|
pool.MaxConnLifetime = 30 * time.Minute
|
||||||
|
}
|
||||||
|
if pool.MaxConnIdleTime <= 0 {
|
||||||
|
pool.MaxConnIdleTime = 5 * time.Minute
|
||||||
|
}
|
||||||
|
if pool.HealthCheckPeriod <= 0 {
|
||||||
|
pool.HealthCheckPeriod = time.Minute
|
||||||
|
}
|
||||||
|
if pool.ConnectionAcquireWait <= 0 {
|
||||||
|
pool.ConnectionAcquireWait = 10 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.MaxConns = int32(pool.MaxOpenConns)
|
||||||
|
cfg.MinConns = int32(pool.MinIdleConns)
|
||||||
|
cfg.MaxConnLifetime = pool.MaxConnLifetime
|
||||||
|
cfg.MaxConnIdleTime = pool.MaxConnIdleTime
|
||||||
|
cfg.HealthCheckPeriod = pool.HealthCheckPeriod
|
||||||
|
|
||||||
|
pingCtx, cancel := context.WithTimeout(ctx, pool.ConnectionAcquireWait)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
poolConn, err := pgxpool.NewWithConfig(ctx, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open postgres pool: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := poolConn.Ping(pingCtx); err != nil {
|
||||||
|
poolConn.Close()
|
||||||
|
return nil, fmt.Errorf("ping postgres: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return poolConn, nil
|
||||||
|
}
|
||||||
26
dbpool/postgres_test.go
Normal file
26
dbpool/postgres_test.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package dbpool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewPostgresPool_ParseConfigError(t *testing.T) {
|
||||||
|
_, err := NewPostgresPool(context.Background(), "://invalid-url", PoolConfig{})
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "parse postgres pool config") {
|
||||||
|
t.Fatalf("expected parse config error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPostgresPool_PingError(t *testing.T) {
|
||||||
|
_, err := NewPostgresPool(
|
||||||
|
context.Background(),
|
||||||
|
"postgres://postgres:postgres@127.0.0.1:1/appdb?sslmode=disable",
|
||||||
|
PoolConfig{ConnectionAcquireWait: 20 * time.Millisecond},
|
||||||
|
)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "ping postgres") {
|
||||||
|
t.Fatalf("expected ping error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
52
dotenv/dotenv.go
Normal file
52
dotenv/dotenv.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LoadDotEnvCandidates loads KEY=VALUE pairs from the first existing file in paths.
|
||||||
|
// Existing process env vars are preserved (file values only fill missing keys).
|
||||||
|
func LoadDotEnvCandidates(paths []string) {
|
||||||
|
for _, path := range paths {
|
||||||
|
if loadDotEnvFile(path) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadDotEnvFile(path string) bool {
|
||||||
|
file, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(file)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
if line == "" || strings.HasPrefix(line, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
key, value, ok := strings.Cut(line, "=")
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
value = strings.Trim(value, `"'`)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := os.LookupEnv(key); exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_ = os.Setenv(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
76
dotenv/dotenv_test.go
Normal file
76
dotenv/dotenv_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadDotEnvFileParsesAndPreservesExistingValues(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, ".env")
|
||||||
|
content := "" +
|
||||||
|
"# comment\n" +
|
||||||
|
"KEY1 = value1\n" +
|
||||||
|
"KEY2='quoted value'\n" +
|
||||||
|
"KEY3=\"double quoted\"\n" +
|
||||||
|
"INVALID_LINE\n" +
|
||||||
|
"=missing_key\n"
|
||||||
|
|
||||||
|
if err := os.WriteFile(path, []byte(content), 0o600); err != nil {
|
||||||
|
t.Fatalf("write env file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Setenv("KEY1", "existing")
|
||||||
|
_ = os.Unsetenv("KEY2")
|
||||||
|
_ = os.Unsetenv("KEY3")
|
||||||
|
|
||||||
|
if ok := loadDotEnvFile(path); !ok {
|
||||||
|
t.Fatal("expected loadDotEnvFile to return true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := os.Getenv("KEY1"); got != "existing" {
|
||||||
|
t.Fatalf("expected KEY1 to preserve existing value, got %q", got)
|
||||||
|
}
|
||||||
|
if got := os.Getenv("KEY2"); got != "quoted value" {
|
||||||
|
t.Fatalf("expected KEY2 parsed value, got %q", got)
|
||||||
|
}
|
||||||
|
if got := os.Getenv("KEY3"); got != "double quoted" {
|
||||||
|
t.Fatalf("expected KEY3 parsed value, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadDotEnvCandidatesStopsAtFirstExistingFile(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
first := filepath.Join(dir, "first.env")
|
||||||
|
second := filepath.Join(dir, "second.env")
|
||||||
|
|
||||||
|
if err := os.WriteFile(first, []byte("KEY_A=first\nKEY_B=from_first\n"), 0o600); err != nil {
|
||||||
|
t.Fatalf("write first env file: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(second, []byte("KEY_A=second\nKEY_C=from_second\n"), 0o600); err != nil {
|
||||||
|
t.Fatalf("write second env file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = os.Unsetenv("KEY_A")
|
||||||
|
_ = os.Unsetenv("KEY_B")
|
||||||
|
_ = os.Unsetenv("KEY_C")
|
||||||
|
|
||||||
|
LoadDotEnvCandidates([]string{filepath.Join(dir, "missing.env"), first, second})
|
||||||
|
|
||||||
|
if got := os.Getenv("KEY_A"); got != "first" {
|
||||||
|
t.Fatalf("expected KEY_A from first file, got %q", got)
|
||||||
|
}
|
||||||
|
if got := os.Getenv("KEY_B"); got != "from_first" {
|
||||||
|
t.Fatalf("expected KEY_B from first file, got %q", got)
|
||||||
|
}
|
||||||
|
if got := os.Getenv("KEY_C"); got != "" {
|
||||||
|
t.Fatalf("expected KEY_C to remain unset, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadDotEnvFileReturnsFalseWhenMissing(t *testing.T) {
|
||||||
|
if ok := loadDotEnvFile(filepath.Join(t.TempDir(), "does-not-exist.env")); ok {
|
||||||
|
t.Fatal("expected loadDotEnvFile to return false for missing file")
|
||||||
|
}
|
||||||
|
}
|
||||||
17
go.mod
Normal file
17
go.mod
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
module git.beatrice.wtf/panic.haus/go-core
|
||||||
|
|
||||||
|
go 1.25
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/golang-migrate/migrate/v4 v4.19.1
|
||||||
|
github.com/jackc/pgx/v5 v5.8.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
|
github.com/lib/pq v1.10.9 // indirect
|
||||||
|
golang.org/x/sync v0.18.0 // indirect
|
||||||
|
golang.org/x/text v0.31.0 // indirect
|
||||||
|
)
|
||||||
93
go.sum
Normal file
93
go.sum
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
|
||||||
|
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
|
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||||
|
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dhui/dktest v0.4.6 h1:+DPKyScKSEp3VLtbMDHcUq6V5Lm5zfZZVb0Sk7Ahom4=
|
||||||
|
github.com/dhui/dktest v0.4.6/go.mod h1:JHTSYDtKkvFNFHJKqCzVzqXecyv+tKt8EzceOmQOgbU=
|
||||||
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
|
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
|
github.com/docker/docker v28.3.3+incompatible h1:Dypm25kh4rmk49v1eiVbsAtpAsYURjYkaKubwuBdxEI=
|
||||||
|
github.com/docker/docker v28.3.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
|
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||||
|
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||||
|
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||||
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
|
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
|
||||||
|
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
||||||
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
|
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||||
|
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||||
|
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
|
||||||
|
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
|
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||||
|
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||||
|
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||||
|
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||||
|
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||||
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
|
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
|
||||||
|
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q=
|
||||||
|
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
|
||||||
|
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
|
||||||
|
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
|
||||||
|
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||||
|
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||||
|
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||||
|
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||||
|
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||||
|
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||||
|
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
|
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||||
|
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
|
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||||
|
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||||
|
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||||
|
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||||
|
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
67
middleware/cors.go
Normal file
67
middleware/cors.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CORSMiddleware struct {
|
||||||
|
allowedOrigins map[string]struct{}
|
||||||
|
allowAll bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCORSMiddleware(origins []string) *CORSMiddleware {
|
||||||
|
m := &CORSMiddleware{allowedOrigins: map[string]struct{}{}}
|
||||||
|
for _, origin := range origins {
|
||||||
|
if origin == "*" {
|
||||||
|
m.allowAll = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.allowedOrigins[origin] = struct{}{}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CORSMiddleware) Handler(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
if origin != "" && (m.allowAll || m.isAllowed(origin) || isSameHost(origin, r.Host)) {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
w.Header().Set("Vary", "Origin")
|
||||||
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
|
||||||
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type,Authorization")
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Method == http.MethodOptions {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CORSMiddleware) isAllowed(origin string) bool {
|
||||||
|
_, ok := m.allowedOrigins[origin]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSameHost(origin string, requestHost string) bool {
|
||||||
|
parsed, err := url.Parse(origin)
|
||||||
|
if err != nil || parsed.Host == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
originHost := parsed.Hostname()
|
||||||
|
reqHost := requestHost
|
||||||
|
if strings.Contains(requestHost, ":") {
|
||||||
|
if host, _, splitErr := net.SplitHostPort(requestHost); splitErr == nil {
|
||||||
|
reqHost = host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.EqualFold(originHost, reqHost)
|
||||||
|
}
|
||||||
137
middleware/cors_test.go
Normal file
137
middleware/cors_test.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCORSHandlerAllowsConfiguredOrigin(t *testing.T) {
|
||||||
|
mw := NewCORSMiddleware([]string{"https://allowed.example"})
|
||||||
|
nextCalled := false
|
||||||
|
h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
nextCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("Origin", "https://allowed.example")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if rr.Header().Get("Access-Control-Allow-Origin") != "https://allowed.example" {
|
||||||
|
t.Fatalf("unexpected allow-origin header: %q", rr.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
}
|
||||||
|
if rr.Header().Get("Access-Control-Allow-Credentials") != "true" {
|
||||||
|
t.Fatalf("expected allow-credentials true, got %q", rr.Header().Get("Access-Control-Allow-Credentials"))
|
||||||
|
}
|
||||||
|
if !nextCalled {
|
||||||
|
t.Fatal("expected next handler to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSHandlerAllowsAnyOriginWhenWildcardConfigured(t *testing.T) {
|
||||||
|
mw := NewCORSMiddleware([]string{"*"})
|
||||||
|
h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("Origin", "https://random.example")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Header().Get("Access-Control-Allow-Origin") != "https://random.example" {
|
||||||
|
t.Fatalf("expected wildcard middleware to echo origin, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
}
|
||||||
|
if rr.Header().Get("Access-Control-Allow-Credentials") != "true" {
|
||||||
|
t.Fatalf("expected allow-credentials true, got %q", rr.Header().Get("Access-Control-Allow-Credentials"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSHandlerAllowsSameHostOrigin(t *testing.T) {
|
||||||
|
mw := NewCORSMiddleware(nil)
|
||||||
|
h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://api.example:8080/path", nil)
|
||||||
|
req.Host = "api.example:8080"
|
||||||
|
req.Header.Set("Origin", "https://api.example")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Header().Get("Access-Control-Allow-Origin") != "https://api.example" {
|
||||||
|
t.Fatalf("expected same-host origin to be allowed, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
}
|
||||||
|
if rr.Header().Get("Access-Control-Allow-Credentials") != "true" {
|
||||||
|
t.Fatalf("expected allow-credentials true, got %q", rr.Header().Get("Access-Control-Allow-Credentials"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSHandlerDoesNotSetHeadersForDisallowedOrigin(t *testing.T) {
|
||||||
|
mw := NewCORSMiddleware([]string{"https://allowed.example"})
|
||||||
|
h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("Origin", "https://blocked.example")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||||
|
t.Fatalf("expected no allow-origin header, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSHandlerOptionsShortCircuits(t *testing.T) {
|
||||||
|
mw := NewCORSMiddleware([]string{"https://allowed.example"})
|
||||||
|
nextCalled := false
|
||||||
|
h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
nextCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
req.Header.Set("Origin", "https://allowed.example")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("expected 204, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if rr.Header().Get("Access-Control-Allow-Credentials") != "true" {
|
||||||
|
t.Fatalf("expected allow-credentials true, got %q", rr.Header().Get("Access-Control-Allow-Credentials"))
|
||||||
|
}
|
||||||
|
if nextCalled {
|
||||||
|
t.Fatal("expected next handler not to be called for OPTIONS")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsSameHost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
origin string
|
||||||
|
requestHost string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "same host exact", origin: "https://api.example", requestHost: "api.example", want: true},
|
||||||
|
{name: "same host with port", origin: "https://api.example", requestHost: "api.example:8080", want: true},
|
||||||
|
{name: "case insensitive", origin: "https://API.EXAMPLE", requestHost: "api.example", want: true},
|
||||||
|
{name: "different host", origin: "https://api.example", requestHost: "other.example", want: false},
|
||||||
|
{name: "invalid origin", origin: "://bad", requestHost: "api.example", want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := isSameHost(tc.origin, tc.requestHost)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Fatalf("isSameHost(%q, %q) = %v, want %v", tc.origin, tc.requestHost, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
133
middleware/rate_limit.go
Normal file
133
middleware/rate_limit.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ipWindowCounter struct {
|
||||||
|
windowStart time.Time
|
||||||
|
lastSeen time.Time
|
||||||
|
count int
|
||||||
|
}
|
||||||
|
|
||||||
|
type IPRateLimiter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
limit int
|
||||||
|
window time.Duration
|
||||||
|
ttl time.Duration
|
||||||
|
maxEntries int
|
||||||
|
lastCleanup time.Time
|
||||||
|
entries map[string]*ipWindowCounter
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultRateLimiterMaxEntries = 10000
|
||||||
|
|
||||||
|
func NewIPRateLimiter(limit int, window time.Duration, ttl time.Duration) *IPRateLimiter {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 60
|
||||||
|
}
|
||||||
|
if window <= 0 {
|
||||||
|
window = time.Minute
|
||||||
|
}
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = 10 * time.Minute
|
||||||
|
}
|
||||||
|
return &IPRateLimiter{
|
||||||
|
limit: limit,
|
||||||
|
window: window,
|
||||||
|
ttl: ttl,
|
||||||
|
maxEntries: defaultRateLimiterMaxEntries,
|
||||||
|
entries: map[string]*ipWindowCounter{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *IPRateLimiter) Handler(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !m.allow(r) {
|
||||||
|
w.Header().Set("Retry-After", strconv.Itoa(int(m.window.Seconds())))
|
||||||
|
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *IPRateLimiter) allow(r *http.Request) bool {
|
||||||
|
key := clientKey(r)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.lastCleanup.IsZero() || now.Sub(m.lastCleanup) >= m.window {
|
||||||
|
m.cleanupLocked(now)
|
||||||
|
m.lastCleanup = now
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, exists := m.entries[key]
|
||||||
|
if !exists {
|
||||||
|
if len(m.entries) >= m.maxEntries {
|
||||||
|
m.evictOldestLocked()
|
||||||
|
}
|
||||||
|
m.entries[key] = &ipWindowCounter{
|
||||||
|
windowStart: now,
|
||||||
|
lastSeen: now,
|
||||||
|
count: 1,
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if now.Sub(entry.windowStart) >= m.window {
|
||||||
|
entry.windowStart = now
|
||||||
|
entry.lastSeen = now
|
||||||
|
entry.count = 1
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
entry.lastSeen = now
|
||||||
|
if entry.count >= m.limit {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
entry.count++
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *IPRateLimiter) cleanupLocked(now time.Time) {
|
||||||
|
for key, entry := range m.entries {
|
||||||
|
if now.Sub(entry.lastSeen) > m.ttl {
|
||||||
|
delete(m.entries, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *IPRateLimiter) evictOldestLocked() {
|
||||||
|
var oldestKey string
|
||||||
|
var oldestTime time.Time
|
||||||
|
first := true
|
||||||
|
for key, entry := range m.entries {
|
||||||
|
if first || entry.lastSeen.Before(oldestTime) {
|
||||||
|
oldestKey = key
|
||||||
|
oldestTime = entry.lastSeen
|
||||||
|
first = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !first {
|
||||||
|
delete(m.entries, oldestKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientKey(r *http.Request) string {
|
||||||
|
host := strings.TrimSpace(r.RemoteAddr)
|
||||||
|
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
|
||||||
|
host = parsedHost
|
||||||
|
}
|
||||||
|
if host == "" {
|
||||||
|
host = "unknown"
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
177
middleware/rate_limit_test.go
Normal file
177
middleware/rate_limit_test.go
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIPRateLimiterRejectsExcessRequestsAcrossAuthPaths(t *testing.T) {
|
||||||
|
limiter := NewIPRateLimiter(2, time.Minute, 5*time.Minute)
|
||||||
|
h := limiter.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
|
||||||
|
req.RemoteAddr = "203.0.113.10:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("request %d expected 204, got %d", i+1, rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
third := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
|
||||||
|
third.RemoteAddr = "203.0.113.10:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr, third)
|
||||||
|
if rr.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("expected 429 for third request, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if rr.Header().Get("Retry-After") == "" {
|
||||||
|
t.Fatal("expected Retry-After header")
|
||||||
|
}
|
||||||
|
|
||||||
|
otherPath := httptest.NewRequest(http.MethodPost, "/auth/register", nil)
|
||||||
|
otherPath.RemoteAddr = "203.0.113.10:12345"
|
||||||
|
rr2 := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr2, otherPath)
|
||||||
|
if rr2.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("expected shared limiter bucket per IP across paths, got %d", rr2.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPRateLimiterCleanupRunsPeriodicallyWithoutEntryThreshold(t *testing.T) {
|
||||||
|
limiter := NewIPRateLimiter(5, 20*time.Millisecond, 20*time.Millisecond)
|
||||||
|
|
||||||
|
first := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
|
||||||
|
first.RemoteAddr = "203.0.113.10:12345"
|
||||||
|
if !limiter.allow(first) {
|
||||||
|
t.Fatal("expected first request to be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
second := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
|
||||||
|
second.RemoteAddr = "203.0.113.11:12345"
|
||||||
|
if !limiter.allow(second) {
|
||||||
|
t.Fatal("expected second request to be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
limiter.mu.Lock()
|
||||||
|
defer limiter.mu.Unlock()
|
||||||
|
if len(limiter.entries) != 1 {
|
||||||
|
t.Fatalf("expected stale entries to be cleaned, got %d entries", len(limiter.entries))
|
||||||
|
}
|
||||||
|
if _, exists := limiter.entries["203.0.113.11"]; !exists {
|
||||||
|
t.Fatal("expected current client key to remain after cleanup")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPRateLimiterEvictsOldestEntryAtCapacity(t *testing.T) {
|
||||||
|
limiter := NewIPRateLimiter(5, time.Minute, time.Hour)
|
||||||
|
limiter.maxEntries = 2
|
||||||
|
|
||||||
|
req1 := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
|
||||||
|
req1.RemoteAddr = "203.0.113.10:12345"
|
||||||
|
if !limiter.allow(req1) {
|
||||||
|
t.Fatal("expected first request to be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
|
||||||
|
req2.RemoteAddr = "203.0.113.11:12345"
|
||||||
|
if !limiter.allow(req2) {
|
||||||
|
t.Fatal("expected second request to be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
|
||||||
|
req3 := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
|
||||||
|
req3.RemoteAddr = "203.0.113.12:12345"
|
||||||
|
if !limiter.allow(req3) {
|
||||||
|
t.Fatal("expected third request to be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
limiter.mu.Lock()
|
||||||
|
defer limiter.mu.Unlock()
|
||||||
|
if len(limiter.entries) != 2 {
|
||||||
|
t.Fatalf("expected limiter size to remain capped at 2, got %d", len(limiter.entries))
|
||||||
|
}
|
||||||
|
if _, exists := limiter.entries["203.0.113.10"]; exists {
|
||||||
|
t.Fatal("expected oldest entry to be evicted")
|
||||||
|
}
|
||||||
|
if _, exists := limiter.entries["203.0.113.11"]; !exists {
|
||||||
|
t.Fatal("expected second entry to remain")
|
||||||
|
}
|
||||||
|
if _, exists := limiter.entries["203.0.113.12"]; !exists {
|
||||||
|
t.Fatal("expected newest entry to remain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPRateLimiterWindowResetAllowsRequestAgain(t *testing.T) {
|
||||||
|
limiter := NewIPRateLimiter(1, 20*time.Millisecond, time.Minute)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
|
||||||
|
req.RemoteAddr = "203.0.113.10:12345"
|
||||||
|
if !limiter.allow(req) {
|
||||||
|
t.Fatal("expected first request to be allowed")
|
||||||
|
}
|
||||||
|
if limiter.allow(req) {
|
||||||
|
t.Fatal("expected second request in same window to be blocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(25 * time.Millisecond)
|
||||||
|
if !limiter.allow(req) {
|
||||||
|
t.Fatal("expected request to be allowed after window reset")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientKeyVariants(t *testing.T) {
|
||||||
|
t.Run("host port", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "203.0.113.10:4321"
|
||||||
|
if got := clientKey(req); got != "203.0.113.10" {
|
||||||
|
t.Fatalf("expected parsed host, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("raw host", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "203.0.113.11"
|
||||||
|
if got := clientKey(req); got != "203.0.113.11" {
|
||||||
|
t.Fatalf("expected raw host, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty uses unknown", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = " "
|
||||||
|
if got := clientKey(req); got != "unknown" {
|
||||||
|
t.Fatalf("expected unknown fallback, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewIPRateLimiterAppliesDefaults(t *testing.T) {
|
||||||
|
limiter := NewIPRateLimiter(0, 0, 0)
|
||||||
|
|
||||||
|
if limiter.limit != 60 {
|
||||||
|
t.Fatalf("expected default limit 60, got %d", limiter.limit)
|
||||||
|
}
|
||||||
|
if limiter.window != time.Minute {
|
||||||
|
t.Fatalf("expected default window %v, got %v", time.Minute, limiter.window)
|
||||||
|
}
|
||||||
|
if limiter.ttl != 10*time.Minute {
|
||||||
|
t.Fatalf("expected default ttl %v, got %v", 10*time.Minute, limiter.ttl)
|
||||||
|
}
|
||||||
|
if limiter.maxEntries != defaultRateLimiterMaxEntries {
|
||||||
|
t.Fatalf("expected maxEntries %d, got %d", defaultRateLimiterMaxEntries, limiter.maxEntries)
|
||||||
|
}
|
||||||
|
if limiter.entries == nil {
|
||||||
|
t.Fatal("expected entries map to be initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
150
migrate/migrate.go
Normal file
150
migrate/migrate.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-migrate/migrate/v4"
|
||||||
|
_ "github.com/golang-migrate/migrate/v4/database/postgres"
|
||||||
|
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultMigrationsPath = "file://migrations"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MigrationConfig struct {
|
||||||
|
Path string
|
||||||
|
LockTimeout time.Duration
|
||||||
|
StartupTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type MigrationStatus struct {
|
||||||
|
Version uint
|
||||||
|
Dirty bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunMigrationsUp(databaseURL string, cfg MigrationConfig) error {
|
||||||
|
exec := func() error {
|
||||||
|
m, err := newMigrator(databaseURL, cfg.Path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer closeMigrator(m)
|
||||||
|
|
||||||
|
if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||||
|
return fmt.Errorf("run migrations up: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return runWithTimeout(cfg.StartupTimeout, exec)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunMigrationsDown(databaseURL string, cfg MigrationConfig, steps int) error {
|
||||||
|
m, err := newMigrator(databaseURL, cfg.Path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer closeMigrator(m)
|
||||||
|
|
||||||
|
if steps > 0 {
|
||||||
|
if err := m.Steps(-steps); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||||
|
return fmt.Errorf("run migrations down %d steps: %w", steps, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||||
|
return fmt.Errorf("run migrations down all: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ForceMigrationVersion(databaseURL string, cfg MigrationConfig, version int) error {
|
||||||
|
m, err := newMigrator(databaseURL, cfg.Path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer closeMigrator(m)
|
||||||
|
|
||||||
|
if err := m.Force(version); err != nil {
|
||||||
|
return fmt.Errorf("force migration version %d: %w", version, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMigrationStatus(databaseURL string, cfg MigrationConfig) (MigrationStatus, error) {
|
||||||
|
m, err := newMigrator(databaseURL, cfg.Path)
|
||||||
|
if err != nil {
|
||||||
|
return MigrationStatus{}, err
|
||||||
|
}
|
||||||
|
defer closeMigrator(m)
|
||||||
|
|
||||||
|
version, dirty, err := m.Version()
|
||||||
|
if errors.Is(err, migrate.ErrNilVersion) {
|
||||||
|
return MigrationStatus{}, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return MigrationStatus{}, fmt.Errorf("read migration version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return MigrationStatus{Version: version, Dirty: dirty}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResolveMigrationsPath(pathValue string) string {
|
||||||
|
if pathValue == "" {
|
||||||
|
pathValue = defaultMigrationsPath
|
||||||
|
}
|
||||||
|
if filepath.IsAbs(pathValue) {
|
||||||
|
return "file://" + pathValue
|
||||||
|
}
|
||||||
|
if len(pathValue) >= len("file://") && pathValue[:len("file://")] == "file://" {
|
||||||
|
return pathValue
|
||||||
|
}
|
||||||
|
wd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return defaultMigrationsPath
|
||||||
|
}
|
||||||
|
return "file://" + filepath.Join(wd, pathValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMigrator(databaseURL string, path string) (*migrate.Migrate, error) {
|
||||||
|
migrationsPath := ResolveMigrationsPath(path)
|
||||||
|
m, err := migrate.New(migrationsPath, databaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open migration driver (%s): %w", migrationsPath, err)
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeMigrator(m *migrate.Migrate) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sourceErr, dbErr := m.Close()
|
||||||
|
if sourceErr != nil || dbErr != nil {
|
||||||
|
_ = sourceErr
|
||||||
|
_ = dbErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runWithTimeout(timeout time.Duration, fn func() error) error {
|
||||||
|
if timeout <= 0 {
|
||||||
|
return fn()
|
||||||
|
}
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- fn()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
return err
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return fmt.Errorf("migration timed out after %s", timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
67
migrate/migrate_integration_test.go
Normal file
67
migrate/migrate_integration_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRunMigrationsUpTwice(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
databaseURL := os.Getenv("INTEGRATION_DB_URL")
|
||||||
|
if databaseURL == "" {
|
||||||
|
t.Skip("set INTEGRATION_DB_URL to run migration integration tests")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := MigrationConfig{Path: resolveMigrationsPath()}
|
||||||
|
if err := RunMigrationsUp(databaseURL, cfg); err != nil {
|
||||||
|
t.Fatalf("first migrate up: %v", err)
|
||||||
|
}
|
||||||
|
if err := RunMigrationsUp(databaseURL, cfg); err != nil {
|
||||||
|
t.Fatalf("second migrate up: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := GetMigrationStatus(databaseURL, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get migration status: %v", err)
|
||||||
|
}
|
||||||
|
if status.Dirty {
|
||||||
|
t.Fatalf("expected non-dirty migration status")
|
||||||
|
}
|
||||||
|
if status.Version == 0 {
|
||||||
|
t.Fatalf("expected migration version > 0 after up")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveMigrationsPath() string {
|
||||||
|
configured := os.Getenv("DB_MIGRATIONS_PATH")
|
||||||
|
if configured == "" {
|
||||||
|
configured = "migrations"
|
||||||
|
}
|
||||||
|
|
||||||
|
if filepath.IsAbs(configured) {
|
||||||
|
return configured
|
||||||
|
}
|
||||||
|
if len(configured) >= len("file://") && configured[:len("file://")] == "file://" {
|
||||||
|
return configured
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(configured); err == nil {
|
||||||
|
return configured
|
||||||
|
}
|
||||||
|
|
||||||
|
_, thisFile, _, ok := runtime.Caller(0)
|
||||||
|
if !ok {
|
||||||
|
return configured
|
||||||
|
}
|
||||||
|
moduleRoot := filepath.Clean(filepath.Join(filepath.Dir(thisFile), "..", ".."))
|
||||||
|
candidate := filepath.Join(moduleRoot, configured)
|
||||||
|
if _, err := os.Stat(candidate); err == nil {
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
|
||||||
|
return configured
|
||||||
|
}
|
||||||
77
migrate/migrate_unit_test.go
Normal file
77
migrate/migrate_unit_test.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolveMigrationsPath(t *testing.T) {
|
||||||
|
if got := ResolveMigrationsPath(""); got != "file://migrations" && !strings.HasSuffix(got, "/migrations") {
|
||||||
|
t.Fatalf("expected default migrations path, got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := ResolveMigrationsPath("file:///tmp/migs"); got != "file:///tmp/migs" {
|
||||||
|
t.Fatalf("expected passthrough file URI, got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
abs := filepath.Join(string(os.PathSeparator), "tmp", "migs")
|
||||||
|
if got := ResolveMigrationsPath(abs); got != "file://"+abs {
|
||||||
|
t.Fatalf("expected absolute path conversion, got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
rel := "migrations"
|
||||||
|
got := ResolveMigrationsPath(rel)
|
||||||
|
if !strings.HasPrefix(got, "file://") || !strings.HasSuffix(got, rel) {
|
||||||
|
t.Fatalf("expected file URI ending with %q, got %q", rel, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunWithTimeout(t *testing.T) {
|
||||||
|
if err := runWithTimeout(0, func() error { return nil }); err != nil {
|
||||||
|
t.Fatalf("expected no error for immediate execution, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := runWithTimeout(50*time.Millisecond, func() error {
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("expected fn to finish before timeout, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := runWithTimeout(10*time.Millisecond, func() error {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "timed out") {
|
||||||
|
t.Fatalf("expected timeout error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := errors.New("boom")
|
||||||
|
err = runWithTimeout(20*time.Millisecond, func() error { return expected })
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatalf("expected function error propagation, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrationAPIs_InvalidSourcePath(t *testing.T) {
|
||||||
|
cfg := MigrationConfig{Path: "definitely-missing-migrations-dir"}
|
||||||
|
|
||||||
|
if err := RunMigrationsUp("postgres://ignored", cfg); err == nil {
|
||||||
|
t.Fatalf("expected RunMigrationsUp error for missing source path")
|
||||||
|
}
|
||||||
|
if err := RunMigrationsDown("postgres://ignored", cfg, 1); err == nil {
|
||||||
|
t.Fatalf("expected RunMigrationsDown error for missing source path")
|
||||||
|
}
|
||||||
|
if err := ForceMigrationVersion("postgres://ignored", cfg, 1); err == nil {
|
||||||
|
t.Fatalf("expected ForceMigrationVersion error for missing source path")
|
||||||
|
}
|
||||||
|
if _, err := GetMigrationStatus("postgres://ignored", cfg); err == nil {
|
||||||
|
t.Fatalf("expected GetMigrationStatus error for missing source path")
|
||||||
|
}
|
||||||
|
|
||||||
|
closeMigrator(nil)
|
||||||
|
}
|
||||||
345
smtp/smtp_mailer.go
Normal file
345
smtp/smtp_mailer.go
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
package smtp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/smtp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SMTPMode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SMTPModeSSL SMTPMode = "ssl"
|
||||||
|
SMTPModeTLS SMTPMode = "tls"
|
||||||
|
SMTPModeUnencrypted SMTPMode = "unencrypted"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SMTPConfig struct {
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
From string
|
||||||
|
Mode SMTPMode
|
||||||
|
Auth string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Mailer interface {
|
||||||
|
Send(ctx context.Context, to string, subject string, htmlBody string, textBody string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type SMTPMailer struct {
|
||||||
|
cfg SMTPConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultSMTPOperationTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
func NewSMTPMailer(cfg SMTPConfig) *SMTPMailer {
|
||||||
|
return &SMTPMailer{cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SMTPMailer) Send(ctx context.Context, to string, subject string, htmlBody string, textBody string) error {
|
||||||
|
if strings.TrimSpace(m.cfg.Host) == "" || m.cfg.Port == 0 || strings.TrimSpace(m.cfg.From) == "" {
|
||||||
|
return fmt.Errorf("smtp is not configured")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
addr := fmt.Sprintf("%s:%d", m.cfg.Host, m.cfg.Port)
|
||||||
|
deadline := smtpSendDeadline(ctx, now)
|
||||||
|
if !deadline.After(now) {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return context.DeadlineExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
var client *smtp.Client
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch m.cfg.Mode {
|
||||||
|
case SMTPModeSSL:
|
||||||
|
client, err = dialSSL(ctx, addr, m.cfg.Host, deadline)
|
||||||
|
case SMTPModeUnencrypted:
|
||||||
|
client, err = dialPlain(ctx, addr, m.cfg.Host, deadline)
|
||||||
|
case SMTPModeTLS:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
client, err = dialStartTLS(ctx, addr, m.cfg.Host, deadline)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = client.Close()
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
if m.cfg.Username != "" {
|
||||||
|
if err := authenticate(client, m.cfg); err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.Mail(m.cfg.From); err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("smtp mail from: %w", err)
|
||||||
|
}
|
||||||
|
if err := client.Rcpt(to); err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("smtp rcpt to: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
writer, err := client.Data()
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("smtp data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
message := buildMIMEMessage(m.cfg.From, to, subject, textBody, htmlBody)
|
||||||
|
if _, err := writer.Write([]byte(message)); err != nil {
|
||||||
|
_ = writer.Close()
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("smtp write body: %w", err)
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("smtp close body: %w", err)
|
||||||
|
}
|
||||||
|
if err := client.Quit(); err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("smtp quit: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func smtpSendDeadline(ctx context.Context, now time.Time) time.Time {
|
||||||
|
if ctx != nil {
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
return deadline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return now.Add(defaultSMTPOperationTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func authenticate(client *smtp.Client, cfg SMTPConfig) error {
|
||||||
|
authMethod := strings.ToLower(strings.TrimSpace(cfg.Auth))
|
||||||
|
if authMethod == "" {
|
||||||
|
authMethod = "auto"
|
||||||
|
}
|
||||||
|
|
||||||
|
_, methodsRaw := client.Extension("AUTH")
|
||||||
|
methods := strings.ToUpper(methodsRaw)
|
||||||
|
|
||||||
|
tryLogin := strings.Contains(methods, "LOGIN")
|
||||||
|
tryPlain := strings.Contains(methods, "PLAIN")
|
||||||
|
|
||||||
|
switch authMethod {
|
||||||
|
case "none":
|
||||||
|
return nil
|
||||||
|
case "login":
|
||||||
|
return authWithLogin(client, cfg)
|
||||||
|
case "plain":
|
||||||
|
return authWithPlain(client, cfg)
|
||||||
|
case "auto":
|
||||||
|
if tryLogin {
|
||||||
|
if err := authWithLogin(client, cfg); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tryPlain {
|
||||||
|
if err := authWithPlain(client, cfg); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !tryLogin && !tryPlain {
|
||||||
|
// Last fallback if server did not advertise methods.
|
||||||
|
if err := authWithLogin(client, cfg); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := authWithPlain(client, cfg); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("smtp auth failed for available methods")
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported smtp auth method: %s", authMethod)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func authWithPlain(client *smtp.Client, cfg SMTPConfig) error {
|
||||||
|
auth := smtp.PlainAuth("", cfg.Username, cfg.Password, cfg.Host)
|
||||||
|
if err := client.Auth(auth); err != nil {
|
||||||
|
return fmt.Errorf("smtp plain auth: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func authWithLogin(client *smtp.Client, cfg SMTPConfig) error {
|
||||||
|
if err := client.Auth(loginAuth{username: cfg.Username, password: cfg.Password}); err != nil {
|
||||||
|
return fmt.Errorf("smtp login auth: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type loginAuth struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a loginAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
|
||||||
|
return "LOGIN", []byte{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
||||||
|
if !more {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
challengeRaw := strings.TrimSpace(string(fromServer))
|
||||||
|
challenge := strings.ToLower(challengeRaw)
|
||||||
|
if decoded, err := base64.StdEncoding.DecodeString(challengeRaw); err == nil {
|
||||||
|
challenge = strings.ToLower(string(decoded))
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case strings.Contains(challenge, "username"):
|
||||||
|
return []byte(a.username), nil
|
||||||
|
case strings.Contains(challenge, "password"):
|
||||||
|
return []byte(a.password), nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected smtp login challenge")
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialSSL(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) {
|
||||||
|
timeout := time.Until(deadline)
|
||||||
|
if timeout <= 0 {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nil, context.DeadlineExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := &tls.Dialer{
|
||||||
|
NetDialer: &net.Dialer{Timeout: timeout},
|
||||||
|
Config: &tls.Config{ServerName: host},
|
||||||
|
}
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("smtp ssl dial: %w", err)
|
||||||
|
}
|
||||||
|
_ = conn.SetDeadline(deadline)
|
||||||
|
client, err := smtp.NewClient(conn, host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("smtp new client ssl: %w", err)
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialPlain(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) {
|
||||||
|
timeout := time.Until(deadline)
|
||||||
|
if timeout <= 0 {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nil, context.DeadlineExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := &net.Dialer{Timeout: timeout}
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("smtp plain dial: %w", err)
|
||||||
|
}
|
||||||
|
_ = conn.SetDeadline(deadline)
|
||||||
|
client, err := smtp.NewClient(conn, host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("smtp new client plain: %w", err)
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialStartTLS(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) {
|
||||||
|
client, err := dialPlain(ctx, addr, host, deadline)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if ok, _ := client.Extension("STARTTLS"); !ok {
|
||||||
|
_ = client.Close()
|
||||||
|
return nil, fmt.Errorf("smtp server does not support STARTTLS")
|
||||||
|
}
|
||||||
|
if err := client.StartTLS(&tls.Config{ServerName: host}); err != nil {
|
||||||
|
_ = client.Close()
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("smtp starttls: %w", err)
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMIMEMessage(from, to, subject, textBody, htmlBody string) string {
|
||||||
|
boundary := randomMIMEBoundary()
|
||||||
|
return strings.Join([]string{
|
||||||
|
fmt.Sprintf("From: %s", from),
|
||||||
|
fmt.Sprintf("To: %s", to),
|
||||||
|
fmt.Sprintf("Subject: %s", subject),
|
||||||
|
"MIME-Version: 1.0",
|
||||||
|
fmt.Sprintf("Content-Type: multipart/alternative; boundary=%s", boundary),
|
||||||
|
"",
|
||||||
|
fmt.Sprintf("--%s", boundary),
|
||||||
|
"Content-Type: text/plain; charset=UTF-8",
|
||||||
|
"",
|
||||||
|
textBody,
|
||||||
|
fmt.Sprintf("--%s", boundary),
|
||||||
|
"Content-Type: text/html; charset=UTF-8",
|
||||||
|
"",
|
||||||
|
htmlBody,
|
||||||
|
fmt.Sprintf("--%s--", boundary),
|
||||||
|
"",
|
||||||
|
}, "\r\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomMIMEBoundary() string {
|
||||||
|
raw := make([]byte, 12)
|
||||||
|
if _, err := rand.Read(raw); err != nil {
|
||||||
|
return fmt.Sprintf("mime-boundary-%d", time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
return "mime-boundary-" + hex.EncodeToString(raw)
|
||||||
|
}
|
||||||
126
smtp/smtp_mailer_additional_test.go
Normal file
126
smtp/smtp_mailer_additional_test.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
package smtp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSendReturnsConfigurationErrorWhenSMTPMissing(t *testing.T) {
|
||||||
|
mailer := NewSMTPMailer(SMTPConfig{Host: "", Port: 587, From: "noreply@example.com", Mode: SMTPModeTLS})
|
||||||
|
|
||||||
|
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>body</p>", "body")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected configuration error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "smtp is not configured") {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendReturnsDeadlineExceededForExpiredContext(t *testing.T) {
|
||||||
|
mailer := NewSMTPMailer(SMTPConfig{Host: "smtp.example.com", Port: 587, From: "noreply@example.com", Mode: SMTPModeTLS})
|
||||||
|
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := mailer.Send(ctx, "to@example.com", "subject", "<p>body</p>", "body")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected context deadline exceeded")
|
||||||
|
}
|
||||||
|
if err != context.DeadlineExceeded {
|
||||||
|
t.Fatalf("expected context deadline exceeded, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginAuthNextHandlesChallenges(t *testing.T) {
|
||||||
|
auth := loginAuth{username: "alice", password: "s3cret"}
|
||||||
|
|
||||||
|
proto, initial, err := auth.Start(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected start error: %v", err)
|
||||||
|
}
|
||||||
|
if proto != "LOGIN" {
|
||||||
|
t.Fatalf("expected LOGIN auth proto, got %q", proto)
|
||||||
|
}
|
||||||
|
if len(initial) != 0 {
|
||||||
|
t.Fatalf("expected empty initial response, got %q", string(initial))
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := auth.Next([]byte("Username:"), true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected username challenge error: %v", err)
|
||||||
|
}
|
||||||
|
if string(value) != "alice" {
|
||||||
|
t.Fatalf("expected username response alice, got %q", string(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err = auth.Next([]byte("UGFzc3dvcmQ6"), true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected password challenge error: %v", err)
|
||||||
|
}
|
||||||
|
if string(value) != "s3cret" {
|
||||||
|
t.Fatalf("expected password response s3cret, got %q", string(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginAuthNextHandlesTerminalAndUnexpectedChallenge(t *testing.T) {
|
||||||
|
auth := loginAuth{username: "alice", password: "s3cret"}
|
||||||
|
|
||||||
|
value, err := auth.Next([]byte("ignored"), false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error when more=false, got %v", err)
|
||||||
|
}
|
||||||
|
if value != nil {
|
||||||
|
t.Fatalf("expected nil value when more=false, got %q", string(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := auth.Next([]byte("realm"), true); err == nil {
|
||||||
|
t.Fatal("expected error for unexpected login challenge")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildMIMEMessageContainsMultipartSections(t *testing.T) {
|
||||||
|
msg := buildMIMEMessage("from@example.com", "to@example.com", "Subject", "plain body", "<p>html body</p>")
|
||||||
|
|
||||||
|
checks := []string{
|
||||||
|
"From: from@example.com",
|
||||||
|
"To: to@example.com",
|
||||||
|
"Subject: Subject",
|
||||||
|
"MIME-Version: 1.0",
|
||||||
|
"Content-Type: text/plain; charset=UTF-8",
|
||||||
|
"plain body",
|
||||||
|
"Content-Type: text/html; charset=UTF-8",
|
||||||
|
"<p>html body</p>",
|
||||||
|
}
|
||||||
|
for _, snippet := range checks {
|
||||||
|
if !strings.Contains(msg, snippet) {
|
||||||
|
t.Fatalf("expected MIME message to contain %q", snippet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.Contains(msg, "\r\n") {
|
||||||
|
t.Fatal("expected CRLF separators in MIME message")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(msg, "Content-Type: multipart/alternative; boundary=mime-boundary-") {
|
||||||
|
t.Fatalf("expected random mime boundary header, got %q", msg)
|
||||||
|
}
|
||||||
|
if !strings.Contains(msg, "--mime-boundary-") {
|
||||||
|
t.Fatalf("expected mime boundary delimiters in message, got %q", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialHelpersReturnDeadlineExceededWhenDeadlineHasPassed(t *testing.T) {
|
||||||
|
deadline := time.Now().Add(-time.Second)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
if _, err := dialPlain(ctx, "smtp.example.com:25", "smtp.example.com", deadline); err != context.DeadlineExceeded {
|
||||||
|
t.Fatalf("dialPlain expected context deadline exceeded, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := dialSSL(ctx, "smtp.example.com:465", "smtp.example.com", deadline); err != context.DeadlineExceeded {
|
||||||
|
t.Fatalf("dialSSL expected context deadline exceeded, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := dialStartTLS(ctx, "smtp.example.com:587", "smtp.example.com", deadline); err != context.DeadlineExceeded {
|
||||||
|
t.Fatalf("dialStartTLS expected context deadline exceeded, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
424
smtp/smtp_mailer_integration_test.go
Normal file
424
smtp/smtp_mailer_integration_test.go
Normal file
@@ -0,0 +1,424 @@
|
|||||||
|
package smtp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type smtpTestServerConfig struct {
|
||||||
|
extensions []string
|
||||||
|
failAuth bool
|
||||||
|
failMail bool
|
||||||
|
failRcpt bool
|
||||||
|
failData bool
|
||||||
|
failQuit bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type smtpTestServerState struct {
|
||||||
|
authCalls int
|
||||||
|
mailFrom string
|
||||||
|
rcptTo string
|
||||||
|
data string
|
||||||
|
}
|
||||||
|
|
||||||
|
type smtpTestServer struct {
|
||||||
|
host string
|
||||||
|
port int
|
||||||
|
state *smtpTestServerState
|
||||||
|
stop func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func requireSMTPIntegration(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping smtp integration test in short mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startSMTPTestServer(t *testing.T, cfg smtpTestServerConfig) *smtpTestServer {
|
||||||
|
t.Helper()
|
||||||
|
requireSMTPIntegration(t)
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen smtp test server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
state := &smtpTestServerState{}
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
r := bufio.NewReader(conn)
|
||||||
|
w := bufio.NewWriter(conn)
|
||||||
|
write := func(line string) bool {
|
||||||
|
if _, err := w.WriteString(line + "\r\n"); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return w.Flush() == nil
|
||||||
|
}
|
||||||
|
if !write("220 localhost ESMTP ready") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inData := false
|
||||||
|
var dataBuf strings.Builder
|
||||||
|
for {
|
||||||
|
line, err := r.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
line = strings.TrimRight(line, "\r\n")
|
||||||
|
|
||||||
|
if inData {
|
||||||
|
if line == "." {
|
||||||
|
state.data = dataBuf.String()
|
||||||
|
if !write("250 2.0.0 queued") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
inData = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dataBuf.WriteString(line)
|
||||||
|
dataBuf.WriteString("\n")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(line, " ", 2)
|
||||||
|
cmd := strings.ToUpper(parts[0])
|
||||||
|
arg := ""
|
||||||
|
if len(parts) > 1 {
|
||||||
|
arg = parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
switch cmd {
|
||||||
|
case "EHLO", "HELO":
|
||||||
|
lines := append([]string{"localhost"}, cfg.extensions...)
|
||||||
|
for i, ext := range lines {
|
||||||
|
prefix := "250-"
|
||||||
|
if i == len(lines)-1 {
|
||||||
|
prefix = "250 "
|
||||||
|
}
|
||||||
|
if !write(prefix + ext) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "AUTH":
|
||||||
|
state.authCalls++
|
||||||
|
if cfg.failAuth {
|
||||||
|
if !write("535 5.7.8 auth failed") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
authArg := strings.ToUpper(strings.TrimSpace(arg))
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(authArg, "PLAIN"):
|
||||||
|
if !write("235 2.7.0 authenticated") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(authArg, "LOGIN"):
|
||||||
|
if !write("334 VXNlcm5hbWU6") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := r.ReadString('\n'); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !write("334 UGFzc3dvcmQ6") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := r.ReadString('\n'); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !write("235 2.7.0 authenticated") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if !write("504 5.5.4 unsupported auth mechanism") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "MAIL":
|
||||||
|
state.mailFrom = arg
|
||||||
|
if cfg.failMail {
|
||||||
|
if !write("550 5.1.0 sender rejected") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !write("250 2.1.0 ok") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "RCPT":
|
||||||
|
state.rcptTo = arg
|
||||||
|
if cfg.failRcpt {
|
||||||
|
if !write("550 5.1.1 recipient rejected") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !write("250 2.1.5 ok") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "DATA":
|
||||||
|
if cfg.failData {
|
||||||
|
if !write("554 5.5.0 data rejected") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !write("354 end data with <CR><LF>.<CR><LF>") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
inData = true
|
||||||
|
dataBuf.Reset()
|
||||||
|
case "QUIT":
|
||||||
|
if cfg.failQuit {
|
||||||
|
_ = write("554 5.5.1 quit failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = write("221 2.0.0 bye")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if !write("250 2.0.0 ok") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tcpAddr := ln.Addr().(*net.TCPAddr)
|
||||||
|
return &smtpTestServer{
|
||||||
|
host: "127.0.0.1",
|
||||||
|
port: tcpAddr.Port,
|
||||||
|
state: state,
|
||||||
|
stop: func() {
|
||||||
|
_ = ln.Close()
|
||||||
|
wg.Wait()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendUnencryptedSuccess(t *testing.T) {
|
||||||
|
srv := startSMTPTestServer(t, smtpTestServerConfig{})
|
||||||
|
defer srv.stop()
|
||||||
|
|
||||||
|
mailer := NewSMTPMailer(SMTPConfig{
|
||||||
|
Host: srv.host,
|
||||||
|
Port: srv.port,
|
||||||
|
From: "noreply@example.com",
|
||||||
|
Mode: SMTPModeUnencrypted,
|
||||||
|
})
|
||||||
|
|
||||||
|
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Send: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(srv.state.mailFrom, "noreply@example.com") {
|
||||||
|
t.Fatalf("expected mail from to include sender, got %q", srv.state.mailFrom)
|
||||||
|
}
|
||||||
|
if !strings.Contains(srv.state.rcptTo, "to@example.com") {
|
||||||
|
t.Fatalf("expected rcpt to include recipient, got %q", srv.state.rcptTo)
|
||||||
|
}
|
||||||
|
if !strings.Contains(srv.state.data, "Subject: subject") {
|
||||||
|
t.Fatalf("expected message data to include subject, got %q", srv.state.data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendUnencryptedErrorPaths(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg smtpTestServerConfig
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{name: "mail from error", cfg: smtpTestServerConfig{failMail: true}, wantErr: "smtp mail from"},
|
||||||
|
{name: "rcpt error", cfg: smtpTestServerConfig{failRcpt: true}, wantErr: "smtp rcpt to"},
|
||||||
|
{name: "data error", cfg: smtpTestServerConfig{failData: true}, wantErr: "smtp data"},
|
||||||
|
{name: "quit error", cfg: smtpTestServerConfig{failQuit: true}, wantErr: "smtp quit"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
srv := startSMTPTestServer(t, tc.cfg)
|
||||||
|
defer srv.stop()
|
||||||
|
|
||||||
|
mailer := NewSMTPMailer(SMTPConfig{
|
||||||
|
Host: srv.host,
|
||||||
|
Port: srv.port,
|
||||||
|
From: "noreply@example.com",
|
||||||
|
Mode: SMTPModeUnencrypted,
|
||||||
|
})
|
||||||
|
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), tc.wantErr) {
|
||||||
|
t.Fatalf("expected %q error, got %v", tc.wantErr, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendAuthModes(t *testing.T) {
|
||||||
|
t.Run("auth none does not issue AUTH command", func(t *testing.T) {
|
||||||
|
srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH PLAIN LOGIN"}})
|
||||||
|
defer srv.stop()
|
||||||
|
|
||||||
|
mailer := NewSMTPMailer(SMTPConfig{
|
||||||
|
Host: srv.host,
|
||||||
|
Port: srv.port,
|
||||||
|
From: "noreply@example.com",
|
||||||
|
Mode: SMTPModeUnencrypted,
|
||||||
|
Username: "alice",
|
||||||
|
Password: "secret",
|
||||||
|
Auth: "none",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello"); err != nil {
|
||||||
|
t.Fatalf("Send: %v", err)
|
||||||
|
}
|
||||||
|
if srv.state.authCalls != 0 {
|
||||||
|
t.Fatalf("expected no AUTH command, got %d", srv.state.authCalls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("auth plain", func(t *testing.T) {
|
||||||
|
srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH PLAIN"}})
|
||||||
|
defer srv.stop()
|
||||||
|
|
||||||
|
mailer := NewSMTPMailer(SMTPConfig{
|
||||||
|
Host: srv.host,
|
||||||
|
Port: srv.port,
|
||||||
|
From: "noreply@example.com",
|
||||||
|
Mode: SMTPModeUnencrypted,
|
||||||
|
Username: "alice",
|
||||||
|
Password: "secret",
|
||||||
|
Auth: "plain",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello"); err != nil {
|
||||||
|
t.Fatalf("Send: %v", err)
|
||||||
|
}
|
||||||
|
if srv.state.authCalls == 0 {
|
||||||
|
t.Fatal("expected AUTH command for plain auth")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("auth login", func(t *testing.T) {
|
||||||
|
srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH LOGIN"}})
|
||||||
|
defer srv.stop()
|
||||||
|
|
||||||
|
mailer := NewSMTPMailer(SMTPConfig{
|
||||||
|
Host: srv.host,
|
||||||
|
Port: srv.port,
|
||||||
|
From: "noreply@example.com",
|
||||||
|
Mode: SMTPModeUnencrypted,
|
||||||
|
Username: "alice",
|
||||||
|
Password: "secret",
|
||||||
|
Auth: "login",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello"); err != nil {
|
||||||
|
t.Fatalf("Send: %v", err)
|
||||||
|
}
|
||||||
|
if srv.state.authCalls == 0 {
|
||||||
|
t.Fatal("expected AUTH command for login auth")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("auth auto failure", func(t *testing.T) {
|
||||||
|
srv := startSMTPTestServer(t, smtpTestServerConfig{failAuth: true})
|
||||||
|
defer srv.stop()
|
||||||
|
|
||||||
|
mailer := NewSMTPMailer(SMTPConfig{
|
||||||
|
Host: srv.host,
|
||||||
|
Port: srv.port,
|
||||||
|
From: "noreply@example.com",
|
||||||
|
Mode: SMTPModeUnencrypted,
|
||||||
|
Username: "alice",
|
||||||
|
Password: "secret",
|
||||||
|
Auth: "auto",
|
||||||
|
})
|
||||||
|
|
||||||
|
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "smtp auth failed") {
|
||||||
|
t.Fatalf("expected smtp auth failed error, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported auth method", func(t *testing.T) {
|
||||||
|
srv := startSMTPTestServer(t, smtpTestServerConfig{})
|
||||||
|
defer srv.stop()
|
||||||
|
|
||||||
|
mailer := NewSMTPMailer(SMTPConfig{
|
||||||
|
Host: srv.host,
|
||||||
|
Port: srv.port,
|
||||||
|
From: "noreply@example.com",
|
||||||
|
Mode: SMTPModeUnencrypted,
|
||||||
|
Username: "alice",
|
||||||
|
Password: "secret",
|
||||||
|
Auth: "weird",
|
||||||
|
})
|
||||||
|
|
||||||
|
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "unsupported smtp auth method") {
|
||||||
|
t.Fatalf("expected unsupported auth method error, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendTLSFailsWhenStartTLSExtensionMissing(t *testing.T) {
|
||||||
|
srv := startSMTPTestServer(t, smtpTestServerConfig{})
|
||||||
|
defer srv.stop()
|
||||||
|
|
||||||
|
mailer := NewSMTPMailer(SMTPConfig{
|
||||||
|
Host: srv.host,
|
||||||
|
Port: srv.port,
|
||||||
|
From: "noreply@example.com",
|
||||||
|
Mode: SMTPModeTLS,
|
||||||
|
})
|
||||||
|
|
||||||
|
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "does not support STARTTLS") {
|
||||||
|
t.Fatalf("expected STARTTLS unsupported error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialHelpersWrapDialErrors(t *testing.T) {
|
||||||
|
requireSMTPIntegration(t)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
deadline := time.Now().Add(500 * time.Millisecond)
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen ephemeral: %v", err)
|
||||||
|
}
|
||||||
|
addr := ln.Addr().String()
|
||||||
|
_ = ln.Close()
|
||||||
|
host, portStr, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("split host port: %v", err)
|
||||||
|
}
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse port: %v", err)
|
||||||
|
}
|
||||||
|
unusedAddr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||||
|
|
||||||
|
if _, err := dialPlain(ctx, unusedAddr, host, deadline); err == nil || !strings.Contains(err.Error(), "smtp plain dial") {
|
||||||
|
t.Fatalf("expected wrapped plain dial error, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := dialSSL(ctx, unusedAddr, host, deadline); err == nil || !strings.Contains(err.Error(), "smtp ssl dial") {
|
||||||
|
t.Fatalf("expected wrapped ssl dial error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
47
smtp/smtp_mailer_test.go
Normal file
47
smtp/smtp_mailer_test.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package smtp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSendHonorsCanceledContext(t *testing.T) {
|
||||||
|
mailer := NewSMTPMailer(SMTPConfig{
|
||||||
|
Host: "smtp.example.com",
|
||||||
|
Port: 587,
|
||||||
|
From: "noreply@example.com",
|
||||||
|
Mode: SMTPModeTLS,
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := mailer.Send(ctx, "to@example.com", "subject", "<p>hello</p>", "hello")
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("expected %v, got %v", context.Canceled, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSMTPSendDeadlineUsesContextDeadline(t *testing.T) {
|
||||||
|
now := time.Date(2026, time.January, 1, 12, 0, 0, 0, time.UTC)
|
||||||
|
want := now.Add(2 * time.Minute)
|
||||||
|
ctx, cancel := context.WithDeadline(context.Background(), want)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
got := smtpSendDeadline(ctx, now)
|
||||||
|
if !got.Equal(want) {
|
||||||
|
t.Fatalf("expected context deadline %v, got %v", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSMTPSendDeadlineUsesDefaultWhenContextHasNoDeadline(t *testing.T) {
|
||||||
|
now := time.Date(2026, time.January, 1, 12, 0, 0, 0, time.UTC)
|
||||||
|
want := now.Add(defaultSMTPOperationTimeout)
|
||||||
|
|
||||||
|
got := smtpSendDeadline(context.Background(), now)
|
||||||
|
if !got.Equal(want) {
|
||||||
|
t.Fatalf("expected default deadline %v, got %v", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
70
worker/poller.go
Normal file
70
worker/poller.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package worker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultTaskName = "batch worker"
|
||||||
|
|
||||||
|
// BatchRunner executes one batch and returns the number of processed records.
|
||||||
|
// Returning 0 stops the current drain cycle.
|
||||||
|
type BatchRunner func(ctx context.Context, limit int) (int, error)
|
||||||
|
|
||||||
|
func Run(ctx context.Context, runner BatchRunner, interval time.Duration, batchSize int, logger *slog.Logger, taskName string) {
|
||||||
|
if runner == nil || interval <= 0 || batchSize <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if logger == nil {
|
||||||
|
logger = slog.Default()
|
||||||
|
}
|
||||||
|
if taskName == "" {
|
||||||
|
taskName = defaultTaskName
|
||||||
|
}
|
||||||
|
|
||||||
|
RunOnce(ctx, runner, batchSize, logger, taskName)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
RunOnce(ctx, runner, batchSize, logger, taskName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunOnce(ctx context.Context, runner BatchRunner, batchSize int, logger *slog.Logger, taskName string) {
|
||||||
|
if runner == nil || batchSize <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if logger == nil {
|
||||||
|
logger = slog.Default()
|
||||||
|
}
|
||||||
|
if taskName == "" {
|
||||||
|
taskName = defaultTaskName
|
||||||
|
}
|
||||||
|
|
||||||
|
totalProcessed := 0
|
||||||
|
for {
|
||||||
|
processed, err := runner(ctx, batchSize)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
logger.Warn(taskName+" run failed", "error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if processed <= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
totalProcessed += processed
|
||||||
|
}
|
||||||
|
|
||||||
|
if totalProcessed > 0 {
|
||||||
|
logger.Info(taskName+" completed", "processed", totalProcessed)
|
||||||
|
}
|
||||||
|
}
|
||||||
125
worker/poller_test.go
Normal file
125
worker/poller_test.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package worker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeScheduledPostPromoter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
calls int
|
||||||
|
limits []int
|
||||||
|
results []int
|
||||||
|
errAt int
|
||||||
|
callHook func(int)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeScheduledPostPromoter) PromoteDueScheduled(_ context.Context, limit int) (int, error) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
|
f.calls++
|
||||||
|
callNo := f.calls
|
||||||
|
f.limits = append(f.limits, limit)
|
||||||
|
if f.callHook != nil {
|
||||||
|
f.callHook(callNo)
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.errAt > 0 && callNo == f.errAt {
|
||||||
|
return 0, errors.New("boom")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(f.results) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := f.results[0]
|
||||||
|
f.results = f.results[1:]
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeScheduledPostPromoter) Calls() int {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
return f.calls
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeScheduledPostPromoter) Limits() []int {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
out := make([]int, len(f.limits))
|
||||||
|
copy(out, f.limits)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDiscardLogger() *slog.Logger {
|
||||||
|
return slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPromoteDueScheduledOnceBatchesUntilEmpty(t *testing.T) {
|
||||||
|
repo := &fakeScheduledPostPromoter{results: []int{2, 3, 0}}
|
||||||
|
|
||||||
|
RunOnce(context.Background(), repo.PromoteDueScheduled, 50, newDiscardLogger(), "scheduled post promotion")
|
||||||
|
|
||||||
|
if got := repo.Calls(); got != 3 {
|
||||||
|
t.Fatalf("expected 3 calls, got %d", got)
|
||||||
|
}
|
||||||
|
limits := repo.Limits()
|
||||||
|
if len(limits) != 3 || limits[0] != 50 || limits[1] != 50 || limits[2] != 50 {
|
||||||
|
t.Fatalf("unexpected limits: %+v", limits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPromoteDueScheduledOnceStopsOnError(t *testing.T) {
|
||||||
|
repo := &fakeScheduledPostPromoter{
|
||||||
|
results: []int{4, 4},
|
||||||
|
errAt: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
RunOnce(context.Background(), repo.PromoteDueScheduled, 25, newDiscardLogger(), "scheduled post promotion")
|
||||||
|
|
||||||
|
if got := repo.Calls(); got != 2 {
|
||||||
|
t.Fatalf("expected 2 calls before stop, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunScheduledPostPromoterStopsOnContextCancel(t *testing.T) {
|
||||||
|
callCh := make(chan int, 8)
|
||||||
|
repo := &fakeScheduledPostPromoter{
|
||||||
|
results: []int{0, 0, 0, 0},
|
||||||
|
callHook: func(call int) {
|
||||||
|
callCh <- call
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
Run(ctx, repo.PromoteDueScheduled, 10*time.Millisecond, 10, newDiscardLogger(), "scheduled post promotion")
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
timeout := time.After(300 * time.Millisecond)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-callCh:
|
||||||
|
if repo.Calls() >= 2 {
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("promoter did not stop after context cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case <-timeout:
|
||||||
|
cancel()
|
||||||
|
t.Fatal("timed out waiting for promoter calls")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user