From baa764befd7a72759ad12316cc6b90a23699c0a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beatrice=20Dellac=C3=A0?= Date: Sun, 1 Mar 2026 03:04:10 +0100 Subject: [PATCH] add core lib --- .drone.yml | 62 ++++ .gitignore | 2 + README.md | 12 + dbpool/postgres.go | 68 +++++ dbpool/postgres_test.go | 26 ++ dotenv/dotenv.go | 52 ++++ dotenv/dotenv_test.go | 76 +++++ go.mod | 17 ++ go.sum | 93 ++++++ middleware/cors.go | 67 +++++ middleware/cors_test.go | 137 +++++++++ middleware/rate_limit.go | 133 +++++++++ middleware/rate_limit_test.go | 177 +++++++++++ migrate/migrate.go | 150 ++++++++++ migrate/migrate_integration_test.go | 67 +++++ migrate/migrate_unit_test.go | 77 +++++ smtp/smtp_mailer.go | 345 ++++++++++++++++++++++ smtp/smtp_mailer_additional_test.go | 126 ++++++++ smtp/smtp_mailer_integration_test.go | 424 +++++++++++++++++++++++++++ smtp/smtp_mailer_test.go | 47 +++ worker/poller.go | 70 +++++ worker/poller_test.go | 125 ++++++++ 22 files changed, 2353 insertions(+) create mode 100644 .drone.yml create mode 100644 .gitignore create mode 100644 README.md create mode 100644 dbpool/postgres.go create mode 100644 dbpool/postgres_test.go create mode 100644 dotenv/dotenv.go create mode 100644 dotenv/dotenv_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 middleware/cors.go create mode 100644 middleware/cors_test.go create mode 100644 middleware/rate_limit.go create mode 100644 middleware/rate_limit_test.go create mode 100644 migrate/migrate.go create mode 100644 migrate/migrate_integration_test.go create mode 100644 migrate/migrate_unit_test.go create mode 100644 smtp/smtp_mailer.go create mode 100644 smtp/smtp_mailer_additional_test.go create mode 100644 smtp/smtp_mailer_integration_test.go create mode 100644 smtp/smtp_mailer_test.go create mode 100644 worker/poller.go create mode 100644 worker/poller_test.go diff --git a/.drone.yml b/.drone.yml new file mode 100644 index 0000000..df35714 --- /dev/null +++ b/.drone.yml @@ -0,0 +1,62 @@ +--- +kind: pipeline +type: docker +name: go-core-ci + +trigger: + branch: + - main + - develop + event: + - push + - pull_request + +environment: + GOPROXY: https://nexus.beatrice.wtf/repository/go-group/,direct + GOPRIVATE: git.beatrice.wtf/panic.haus/* + GONOSUMDB: git.beatrice.wtf/panic.haus/* + +steps: + - name: test + image: golang:1.26 + commands: + - go mod download + - go test ./... -count=1 + + - name: build + image: golang:1.26 + commands: + - go build ./... + +--- +kind: pipeline +type: docker +name: go-core-release + +trigger: + event: + - tag + ref: + - refs/tags/v* + +environment: + GOPROXY: https://nexus.beatrice.wtf/repository/go-group/,direct + GOPRIVATE: git.beatrice.wtf/panic.haus/* + GONOSUMDB: git.beatrice.wtf/panic.haus/* + +steps: + - name: test + image: golang:1.26 + commands: + - go mod download + - go test ./... -count=1 + + - name: build + image: golang:1.26 + commands: + - go build ./... + + - name: warm-proxy + image: golang:1.26 + commands: + - go list -m git.beatrice.wtf/panic.haus/go-core@${DRONE_TAG} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..21e3b07 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +coverage.out +bin/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..79db3cd --- /dev/null +++ b/README.md @@ -0,0 +1,12 @@ +# go-core + +Reusable backend infrastructure module. + +## Packages + +- `dotenv`: load dotenv candidates while preserving existing env vars. +- `dbpool`: PostgreSQL pool bootstrap for pgx. +- `migrate`: migration helpers around `golang-migrate`. +- `smtp`: SMTP mailer implementation. +- `middleware`: reusable HTTP middleware (`CORS`, IP rate limiter). +- `worker`: generic batch poller/runner utilities. diff --git a/dbpool/postgres.go b/dbpool/postgres.go new file mode 100644 index 0000000..dc8c070 --- /dev/null +++ b/dbpool/postgres.go @@ -0,0 +1,68 @@ +package dbpool + +import ( + "context" + "fmt" + "time" + + "github.com/jackc/pgx/v5/pgxpool" +) + +type PoolConfig struct { + MaxOpenConns int + MinIdleConns int + MaxConnLifetime time.Duration + MaxConnIdleTime time.Duration + HealthCheckPeriod time.Duration + ConnectionAcquireWait time.Duration +} + +func NewPostgresPool(ctx context.Context, databaseURL string, pool PoolConfig) (*pgxpool.Pool, error) { + cfg, err := pgxpool.ParseConfig(databaseURL) + if err != nil { + return nil, fmt.Errorf("parse postgres pool config: %w", err) + } + + if pool.MaxOpenConns <= 0 { + pool.MaxOpenConns = 25 + } + if pool.MinIdleConns < 0 { + pool.MinIdleConns = 0 + } + if pool.MinIdleConns > pool.MaxOpenConns { + pool.MinIdleConns = pool.MaxOpenConns + } + if pool.MaxConnLifetime <= 0 { + pool.MaxConnLifetime = 30 * time.Minute + } + if pool.MaxConnIdleTime <= 0 { + pool.MaxConnIdleTime = 5 * time.Minute + } + if pool.HealthCheckPeriod <= 0 { + pool.HealthCheckPeriod = time.Minute + } + if pool.ConnectionAcquireWait <= 0 { + pool.ConnectionAcquireWait = 10 * time.Second + } + + cfg.MaxConns = int32(pool.MaxOpenConns) + cfg.MinConns = int32(pool.MinIdleConns) + cfg.MaxConnLifetime = pool.MaxConnLifetime + cfg.MaxConnIdleTime = pool.MaxConnIdleTime + cfg.HealthCheckPeriod = pool.HealthCheckPeriod + + pingCtx, cancel := context.WithTimeout(ctx, pool.ConnectionAcquireWait) + defer cancel() + + poolConn, err := pgxpool.NewWithConfig(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("open postgres pool: %w", err) + } + + if err := poolConn.Ping(pingCtx); err != nil { + poolConn.Close() + return nil, fmt.Errorf("ping postgres: %w", err) + } + + return poolConn, nil +} diff --git a/dbpool/postgres_test.go b/dbpool/postgres_test.go new file mode 100644 index 0000000..59e2682 --- /dev/null +++ b/dbpool/postgres_test.go @@ -0,0 +1,26 @@ +package dbpool + +import ( + "context" + "strings" + "testing" + "time" +) + +func TestNewPostgresPool_ParseConfigError(t *testing.T) { + _, err := NewPostgresPool(context.Background(), "://invalid-url", PoolConfig{}) + if err == nil || !strings.Contains(err.Error(), "parse postgres pool config") { + t.Fatalf("expected parse config error, got %v", err) + } +} + +func TestNewPostgresPool_PingError(t *testing.T) { + _, err := NewPostgresPool( + context.Background(), + "postgres://postgres:postgres@127.0.0.1:1/appdb?sslmode=disable", + PoolConfig{ConnectionAcquireWait: 20 * time.Millisecond}, + ) + if err == nil || !strings.Contains(err.Error(), "ping postgres") { + t.Fatalf("expected ping error, got %v", err) + } +} diff --git a/dotenv/dotenv.go b/dotenv/dotenv.go new file mode 100644 index 0000000..1308ca9 --- /dev/null +++ b/dotenv/dotenv.go @@ -0,0 +1,52 @@ +package config + +import ( + "bufio" + "os" + "strings" +) + +// LoadDotEnvCandidates loads KEY=VALUE pairs from the first existing file in paths. +// Existing process env vars are preserved (file values only fill missing keys). +func LoadDotEnvCandidates(paths []string) { + for _, path := range paths { + if loadDotEnvFile(path) { + return + } + } +} + +func loadDotEnvFile(path string) bool { + file, err := os.Open(path) + if err != nil { + return false + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + key, value, ok := strings.Cut(line, "=") + if !ok { + continue + } + + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + value = strings.Trim(value, `"'`) + if key == "" { + continue + } + + if _, exists := os.LookupEnv(key); exists { + continue + } + _ = os.Setenv(key, value) + } + + return true +} diff --git a/dotenv/dotenv_test.go b/dotenv/dotenv_test.go new file mode 100644 index 0000000..2216b5b --- /dev/null +++ b/dotenv/dotenv_test.go @@ -0,0 +1,76 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadDotEnvFileParsesAndPreservesExistingValues(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, ".env") + content := "" + + "# comment\n" + + "KEY1 = value1\n" + + "KEY2='quoted value'\n" + + "KEY3=\"double quoted\"\n" + + "INVALID_LINE\n" + + "=missing_key\n" + + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatalf("write env file: %v", err) + } + + t.Setenv("KEY1", "existing") + _ = os.Unsetenv("KEY2") + _ = os.Unsetenv("KEY3") + + if ok := loadDotEnvFile(path); !ok { + t.Fatal("expected loadDotEnvFile to return true") + } + + if got := os.Getenv("KEY1"); got != "existing" { + t.Fatalf("expected KEY1 to preserve existing value, got %q", got) + } + if got := os.Getenv("KEY2"); got != "quoted value" { + t.Fatalf("expected KEY2 parsed value, got %q", got) + } + if got := os.Getenv("KEY3"); got != "double quoted" { + t.Fatalf("expected KEY3 parsed value, got %q", got) + } +} + +func TestLoadDotEnvCandidatesStopsAtFirstExistingFile(t *testing.T) { + dir := t.TempDir() + first := filepath.Join(dir, "first.env") + second := filepath.Join(dir, "second.env") + + if err := os.WriteFile(first, []byte("KEY_A=first\nKEY_B=from_first\n"), 0o600); err != nil { + t.Fatalf("write first env file: %v", err) + } + if err := os.WriteFile(second, []byte("KEY_A=second\nKEY_C=from_second\n"), 0o600); err != nil { + t.Fatalf("write second env file: %v", err) + } + + _ = os.Unsetenv("KEY_A") + _ = os.Unsetenv("KEY_B") + _ = os.Unsetenv("KEY_C") + + LoadDotEnvCandidates([]string{filepath.Join(dir, "missing.env"), first, second}) + + if got := os.Getenv("KEY_A"); got != "first" { + t.Fatalf("expected KEY_A from first file, got %q", got) + } + if got := os.Getenv("KEY_B"); got != "from_first" { + t.Fatalf("expected KEY_B from first file, got %q", got) + } + if got := os.Getenv("KEY_C"); got != "" { + t.Fatalf("expected KEY_C to remain unset, got %q", got) + } +} + +func TestLoadDotEnvFileReturnsFalseWhenMissing(t *testing.T) { + if ok := loadDotEnvFile(filepath.Join(t.TempDir(), "does-not-exist.env")); ok { + t.Fatal("expected loadDotEnvFile to return false for missing file") + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..68338a2 --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module git.beatrice.wtf/panic.haus/go-core + +go 1.25 + +require ( + github.com/golang-migrate/migrate/v4 v4.19.1 + github.com/jackc/pgx/v5 v5.8.0 +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/lib/pq v1.10.9 // indirect + golang.org/x/sync v0.18.0 // indirect + golang.org/x/text v0.31.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..1890f75 --- /dev/null +++ b/go.sum @@ -0,0 +1,93 @@ +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= +github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= +github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= +github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dhui/dktest v0.4.6 h1:+DPKyScKSEp3VLtbMDHcUq6V5Lm5zfZZVb0Sk7Ahom4= +github.com/dhui/dktest v0.4.6/go.mod h1:JHTSYDtKkvFNFHJKqCzVzqXecyv+tKt8EzceOmQOgbU= +github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= +github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/docker/docker v28.3.3+incompatible h1:Dypm25kh4rmk49v1eiVbsAtpAsYURjYkaKubwuBdxEI= +github.com/docker/docker v28.3.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= +github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= +github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA= +github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= +github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= +github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= +github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= +github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= +github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= +github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/middleware/cors.go b/middleware/cors.go new file mode 100644 index 0000000..b3a2018 --- /dev/null +++ b/middleware/cors.go @@ -0,0 +1,67 @@ +package middleware + +import ( + "net" + "net/http" + "net/url" + "strings" +) + +type CORSMiddleware struct { + allowedOrigins map[string]struct{} + allowAll bool +} + +func NewCORSMiddleware(origins []string) *CORSMiddleware { + m := &CORSMiddleware{allowedOrigins: map[string]struct{}{}} + for _, origin := range origins { + if origin == "*" { + m.allowAll = true + continue + } + m.allowedOrigins[origin] = struct{}{} + } + return m +} + +func (m *CORSMiddleware) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin != "" && (m.allowAll || m.isAllowed(origin) || isSameHost(origin, r.Host)) { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Vary", "Origin") + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type,Authorization") + } + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) +} + +func (m *CORSMiddleware) isAllowed(origin string) bool { + _, ok := m.allowedOrigins[origin] + return ok +} + +func isSameHost(origin string, requestHost string) bool { + parsed, err := url.Parse(origin) + if err != nil || parsed.Host == "" { + return false + } + + originHost := parsed.Hostname() + reqHost := requestHost + if strings.Contains(requestHost, ":") { + if host, _, splitErr := net.SplitHostPort(requestHost); splitErr == nil { + reqHost = host + } + } + + return strings.EqualFold(originHost, reqHost) +} diff --git a/middleware/cors_test.go b/middleware/cors_test.go new file mode 100644 index 0000000..6215898 --- /dev/null +++ b/middleware/cors_test.go @@ -0,0 +1,137 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestCORSHandlerAllowsConfiguredOrigin(t *testing.T) { + mw := NewCORSMiddleware([]string{"https://allowed.example"}) + nextCalled := false + h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "https://allowed.example") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + if rr.Header().Get("Access-Control-Allow-Origin") != "https://allowed.example" { + t.Fatalf("unexpected allow-origin header: %q", rr.Header().Get("Access-Control-Allow-Origin")) + } + if rr.Header().Get("Access-Control-Allow-Credentials") != "true" { + t.Fatalf("expected allow-credentials true, got %q", rr.Header().Get("Access-Control-Allow-Credentials")) + } + if !nextCalled { + t.Fatal("expected next handler to be called") + } +} + +func TestCORSHandlerAllowsAnyOriginWhenWildcardConfigured(t *testing.T) { + mw := NewCORSMiddleware([]string{"*"}) + h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "https://random.example") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + if rr.Header().Get("Access-Control-Allow-Origin") != "https://random.example" { + t.Fatalf("expected wildcard middleware to echo origin, got %q", rr.Header().Get("Access-Control-Allow-Origin")) + } + if rr.Header().Get("Access-Control-Allow-Credentials") != "true" { + t.Fatalf("expected allow-credentials true, got %q", rr.Header().Get("Access-Control-Allow-Credentials")) + } +} + +func TestCORSHandlerAllowsSameHostOrigin(t *testing.T) { + mw := NewCORSMiddleware(nil) + h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://api.example:8080/path", nil) + req.Host = "api.example:8080" + req.Header.Set("Origin", "https://api.example") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + if rr.Header().Get("Access-Control-Allow-Origin") != "https://api.example" { + t.Fatalf("expected same-host origin to be allowed, got %q", rr.Header().Get("Access-Control-Allow-Origin")) + } + if rr.Header().Get("Access-Control-Allow-Credentials") != "true" { + t.Fatalf("expected allow-credentials true, got %q", rr.Header().Get("Access-Control-Allow-Credentials")) + } +} + +func TestCORSHandlerDoesNotSetHeadersForDisallowedOrigin(t *testing.T) { + mw := NewCORSMiddleware([]string{"https://allowed.example"}) + h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "https://blocked.example") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + if rr.Header().Get("Access-Control-Allow-Origin") != "" { + t.Fatalf("expected no allow-origin header, got %q", rr.Header().Get("Access-Control-Allow-Origin")) + } +} + +func TestCORSHandlerOptionsShortCircuits(t *testing.T) { + mw := NewCORSMiddleware([]string{"https://allowed.example"}) + nextCalled := false + h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodOptions, "/", nil) + req.Header.Set("Origin", "https://allowed.example") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("expected 204, got %d", rr.Code) + } + if rr.Header().Get("Access-Control-Allow-Credentials") != "true" { + t.Fatalf("expected allow-credentials true, got %q", rr.Header().Get("Access-Control-Allow-Credentials")) + } + if nextCalled { + t.Fatal("expected next handler not to be called for OPTIONS") + } +} + +func TestIsSameHost(t *testing.T) { + tests := []struct { + name string + origin string + requestHost string + want bool + }{ + {name: "same host exact", origin: "https://api.example", requestHost: "api.example", want: true}, + {name: "same host with port", origin: "https://api.example", requestHost: "api.example:8080", want: true}, + {name: "case insensitive", origin: "https://API.EXAMPLE", requestHost: "api.example", want: true}, + {name: "different host", origin: "https://api.example", requestHost: "other.example", want: false}, + {name: "invalid origin", origin: "://bad", requestHost: "api.example", want: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := isSameHost(tc.origin, tc.requestHost) + if got != tc.want { + t.Fatalf("isSameHost(%q, %q) = %v, want %v", tc.origin, tc.requestHost, got, tc.want) + } + }) + } +} diff --git a/middleware/rate_limit.go b/middleware/rate_limit.go new file mode 100644 index 0000000..6e7e270 --- /dev/null +++ b/middleware/rate_limit.go @@ -0,0 +1,133 @@ +package middleware + +import ( + "net" + "net/http" + "strconv" + "strings" + "sync" + "time" +) + +type ipWindowCounter struct { + windowStart time.Time + lastSeen time.Time + count int +} + +type IPRateLimiter struct { + mu sync.Mutex + limit int + window time.Duration + ttl time.Duration + maxEntries int + lastCleanup time.Time + entries map[string]*ipWindowCounter +} + +const defaultRateLimiterMaxEntries = 10000 + +func NewIPRateLimiter(limit int, window time.Duration, ttl time.Duration) *IPRateLimiter { + if limit <= 0 { + limit = 60 + } + if window <= 0 { + window = time.Minute + } + if ttl <= 0 { + ttl = 10 * time.Minute + } + return &IPRateLimiter{ + limit: limit, + window: window, + ttl: ttl, + maxEntries: defaultRateLimiterMaxEntries, + entries: map[string]*ipWindowCounter{}, + } +} + +func (m *IPRateLimiter) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !m.allow(r) { + w.Header().Set("Retry-After", strconv.Itoa(int(m.window.Seconds()))) + http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) +} + +func (m *IPRateLimiter) allow(r *http.Request) bool { + key := clientKey(r) + now := time.Now() + + m.mu.Lock() + defer m.mu.Unlock() + + if m.lastCleanup.IsZero() || now.Sub(m.lastCleanup) >= m.window { + m.cleanupLocked(now) + m.lastCleanup = now + } + + entry, exists := m.entries[key] + if !exists { + if len(m.entries) >= m.maxEntries { + m.evictOldestLocked() + } + m.entries[key] = &ipWindowCounter{ + windowStart: now, + lastSeen: now, + count: 1, + } + return true + } + + if now.Sub(entry.windowStart) >= m.window { + entry.windowStart = now + entry.lastSeen = now + entry.count = 1 + return true + } + + entry.lastSeen = now + if entry.count >= m.limit { + return false + } + entry.count++ + return true +} + +func (m *IPRateLimiter) cleanupLocked(now time.Time) { + for key, entry := range m.entries { + if now.Sub(entry.lastSeen) > m.ttl { + delete(m.entries, key) + } + } +} + +func (m *IPRateLimiter) evictOldestLocked() { + var oldestKey string + var oldestTime time.Time + first := true + for key, entry := range m.entries { + if first || entry.lastSeen.Before(oldestTime) { + oldestKey = key + oldestTime = entry.lastSeen + first = false + } + } + if !first { + delete(m.entries, oldestKey) + } +} + +func clientKey(r *http.Request) string { + host := strings.TrimSpace(r.RemoteAddr) + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + host = parsedHost + } + if host == "" { + host = "unknown" + } + return host +} diff --git a/middleware/rate_limit_test.go b/middleware/rate_limit_test.go new file mode 100644 index 0000000..f439cb7 --- /dev/null +++ b/middleware/rate_limit_test.go @@ -0,0 +1,177 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestIPRateLimiterRejectsExcessRequestsAcrossAuthPaths(t *testing.T) { + limiter := NewIPRateLimiter(2, time.Minute, 5*time.Minute) + h := limiter.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + + for i := 0; i < 2; i++ { + req := httptest.NewRequest(http.MethodPost, "/auth/login", nil) + req.RemoteAddr = "203.0.113.10:12345" + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != http.StatusNoContent { + t.Fatalf("request %d expected 204, got %d", i+1, rr.Code) + } + } + + third := httptest.NewRequest(http.MethodPost, "/auth/login", nil) + third.RemoteAddr = "203.0.113.10:12345" + rr := httptest.NewRecorder() + h.ServeHTTP(rr, third) + if rr.Code != http.StatusTooManyRequests { + t.Fatalf("expected 429 for third request, got %d", rr.Code) + } + if rr.Header().Get("Retry-After") == "" { + t.Fatal("expected Retry-After header") + } + + otherPath := httptest.NewRequest(http.MethodPost, "/auth/register", nil) + otherPath.RemoteAddr = "203.0.113.10:12345" + rr2 := httptest.NewRecorder() + h.ServeHTTP(rr2, otherPath) + if rr2.Code != http.StatusTooManyRequests { + t.Fatalf("expected shared limiter bucket per IP across paths, got %d", rr2.Code) + } +} + +func TestIPRateLimiterCleanupRunsPeriodicallyWithoutEntryThreshold(t *testing.T) { + limiter := NewIPRateLimiter(5, 20*time.Millisecond, 20*time.Millisecond) + + first := httptest.NewRequest(http.MethodPost, "/auth/login", nil) + first.RemoteAddr = "203.0.113.10:12345" + if !limiter.allow(first) { + t.Fatal("expected first request to be allowed") + } + + time.Sleep(50 * time.Millisecond) + + second := httptest.NewRequest(http.MethodPost, "/auth/login", nil) + second.RemoteAddr = "203.0.113.11:12345" + if !limiter.allow(second) { + t.Fatal("expected second request to be allowed") + } + + limiter.mu.Lock() + defer limiter.mu.Unlock() + if len(limiter.entries) != 1 { + t.Fatalf("expected stale entries to be cleaned, got %d entries", len(limiter.entries)) + } + if _, exists := limiter.entries["203.0.113.11"]; !exists { + t.Fatal("expected current client key to remain after cleanup") + } +} + +func TestIPRateLimiterEvictsOldestEntryAtCapacity(t *testing.T) { + limiter := NewIPRateLimiter(5, time.Minute, time.Hour) + limiter.maxEntries = 2 + + req1 := httptest.NewRequest(http.MethodPost, "/auth/login", nil) + req1.RemoteAddr = "203.0.113.10:12345" + if !limiter.allow(req1) { + t.Fatal("expected first request to be allowed") + } + + time.Sleep(2 * time.Millisecond) + + req2 := httptest.NewRequest(http.MethodPost, "/auth/login", nil) + req2.RemoteAddr = "203.0.113.11:12345" + if !limiter.allow(req2) { + t.Fatal("expected second request to be allowed") + } + + time.Sleep(2 * time.Millisecond) + + req3 := httptest.NewRequest(http.MethodPost, "/auth/login", nil) + req3.RemoteAddr = "203.0.113.12:12345" + if !limiter.allow(req3) { + t.Fatal("expected third request to be allowed") + } + + limiter.mu.Lock() + defer limiter.mu.Unlock() + if len(limiter.entries) != 2 { + t.Fatalf("expected limiter size to remain capped at 2, got %d", len(limiter.entries)) + } + if _, exists := limiter.entries["203.0.113.10"]; exists { + t.Fatal("expected oldest entry to be evicted") + } + if _, exists := limiter.entries["203.0.113.11"]; !exists { + t.Fatal("expected second entry to remain") + } + if _, exists := limiter.entries["203.0.113.12"]; !exists { + t.Fatal("expected newest entry to remain") + } +} + +func TestIPRateLimiterWindowResetAllowsRequestAgain(t *testing.T) { + limiter := NewIPRateLimiter(1, 20*time.Millisecond, time.Minute) + + req := httptest.NewRequest(http.MethodPost, "/auth/login", nil) + req.RemoteAddr = "203.0.113.10:12345" + if !limiter.allow(req) { + t.Fatal("expected first request to be allowed") + } + if limiter.allow(req) { + t.Fatal("expected second request in same window to be blocked") + } + + time.Sleep(25 * time.Millisecond) + if !limiter.allow(req) { + t.Fatal("expected request to be allowed after window reset") + } +} + +func TestClientKeyVariants(t *testing.T) { + t.Run("host port", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "203.0.113.10:4321" + if got := clientKey(req); got != "203.0.113.10" { + t.Fatalf("expected parsed host, got %q", got) + } + }) + + t.Run("raw host", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "203.0.113.11" + if got := clientKey(req); got != "203.0.113.11" { + t.Fatalf("expected raw host, got %q", got) + } + }) + + t.Run("empty uses unknown", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = " " + if got := clientKey(req); got != "unknown" { + t.Fatalf("expected unknown fallback, got %q", got) + } + }) +} + +func TestNewIPRateLimiterAppliesDefaults(t *testing.T) { + limiter := NewIPRateLimiter(0, 0, 0) + + if limiter.limit != 60 { + t.Fatalf("expected default limit 60, got %d", limiter.limit) + } + if limiter.window != time.Minute { + t.Fatalf("expected default window %v, got %v", time.Minute, limiter.window) + } + if limiter.ttl != 10*time.Minute { + t.Fatalf("expected default ttl %v, got %v", 10*time.Minute, limiter.ttl) + } + if limiter.maxEntries != defaultRateLimiterMaxEntries { + t.Fatalf("expected maxEntries %d, got %d", defaultRateLimiterMaxEntries, limiter.maxEntries) + } + if limiter.entries == nil { + t.Fatal("expected entries map to be initialized") + } +} diff --git a/migrate/migrate.go b/migrate/migrate.go new file mode 100644 index 0000000..f5561cb --- /dev/null +++ b/migrate/migrate.go @@ -0,0 +1,150 @@ +package migrate + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/golang-migrate/migrate/v4" + _ "github.com/golang-migrate/migrate/v4/database/postgres" + _ "github.com/golang-migrate/migrate/v4/source/file" +) + +const ( + defaultMigrationsPath = "file://migrations" +) + +type MigrationConfig struct { + Path string + LockTimeout time.Duration + StartupTimeout time.Duration +} + +type MigrationStatus struct { + Version uint + Dirty bool +} + +func RunMigrationsUp(databaseURL string, cfg MigrationConfig) error { + exec := func() error { + m, err := newMigrator(databaseURL, cfg.Path) + if err != nil { + return err + } + defer closeMigrator(m) + + if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) { + return fmt.Errorf("run migrations up: %w", err) + } + return nil + } + + return runWithTimeout(cfg.StartupTimeout, exec) +} + +func RunMigrationsDown(databaseURL string, cfg MigrationConfig, steps int) error { + m, err := newMigrator(databaseURL, cfg.Path) + if err != nil { + return err + } + defer closeMigrator(m) + + if steps > 0 { + if err := m.Steps(-steps); err != nil && !errors.Is(err, migrate.ErrNoChange) { + return fmt.Errorf("run migrations down %d steps: %w", steps, err) + } + return nil + } + + if err := m.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) { + return fmt.Errorf("run migrations down all: %w", err) + } + return nil +} + +func ForceMigrationVersion(databaseURL string, cfg MigrationConfig, version int) error { + m, err := newMigrator(databaseURL, cfg.Path) + if err != nil { + return err + } + defer closeMigrator(m) + + if err := m.Force(version); err != nil { + return fmt.Errorf("force migration version %d: %w", version, err) + } + return nil +} + +func GetMigrationStatus(databaseURL string, cfg MigrationConfig) (MigrationStatus, error) { + m, err := newMigrator(databaseURL, cfg.Path) + if err != nil { + return MigrationStatus{}, err + } + defer closeMigrator(m) + + version, dirty, err := m.Version() + if errors.Is(err, migrate.ErrNilVersion) { + return MigrationStatus{}, nil + } + if err != nil { + return MigrationStatus{}, fmt.Errorf("read migration version: %w", err) + } + + return MigrationStatus{Version: version, Dirty: dirty}, nil +} + +func ResolveMigrationsPath(pathValue string) string { + if pathValue == "" { + pathValue = defaultMigrationsPath + } + if filepath.IsAbs(pathValue) { + return "file://" + pathValue + } + if len(pathValue) >= len("file://") && pathValue[:len("file://")] == "file://" { + return pathValue + } + wd, err := os.Getwd() + if err != nil { + return defaultMigrationsPath + } + return "file://" + filepath.Join(wd, pathValue) +} + +func newMigrator(databaseURL string, path string) (*migrate.Migrate, error) { + migrationsPath := ResolveMigrationsPath(path) + m, err := migrate.New(migrationsPath, databaseURL) + if err != nil { + return nil, fmt.Errorf("open migration driver (%s): %w", migrationsPath, err) + } + return m, nil +} + +func closeMigrator(m *migrate.Migrate) { + if m == nil { + return + } + sourceErr, dbErr := m.Close() + if sourceErr != nil || dbErr != nil { + _ = sourceErr + _ = dbErr + } +} + +func runWithTimeout(timeout time.Duration, fn func() error) error { + if timeout <= 0 { + return fn() + } + done := make(chan error, 1) + go func() { + done <- fn() + }() + + select { + case err := <-done: + return err + case <-time.After(timeout): + return fmt.Errorf("migration timed out after %s", timeout) + } +} diff --git a/migrate/migrate_integration_test.go b/migrate/migrate_integration_test.go new file mode 100644 index 0000000..21bdd26 --- /dev/null +++ b/migrate/migrate_integration_test.go @@ -0,0 +1,67 @@ +package migrate + +import ( + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestRunMigrationsUpTwice(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + databaseURL := os.Getenv("INTEGRATION_DB_URL") + if databaseURL == "" { + t.Skip("set INTEGRATION_DB_URL to run migration integration tests") + } + + cfg := MigrationConfig{Path: resolveMigrationsPath()} + if err := RunMigrationsUp(databaseURL, cfg); err != nil { + t.Fatalf("first migrate up: %v", err) + } + if err := RunMigrationsUp(databaseURL, cfg); err != nil { + t.Fatalf("second migrate up: %v", err) + } + + status, err := GetMigrationStatus(databaseURL, cfg) + if err != nil { + t.Fatalf("get migration status: %v", err) + } + if status.Dirty { + t.Fatalf("expected non-dirty migration status") + } + if status.Version == 0 { + t.Fatalf("expected migration version > 0 after up") + } +} + +func resolveMigrationsPath() string { + configured := os.Getenv("DB_MIGRATIONS_PATH") + if configured == "" { + configured = "migrations" + } + + if filepath.IsAbs(configured) { + return configured + } + if len(configured) >= len("file://") && configured[:len("file://")] == "file://" { + return configured + } + if _, err := os.Stat(configured); err == nil { + return configured + } + + _, thisFile, _, ok := runtime.Caller(0) + if !ok { + return configured + } + moduleRoot := filepath.Clean(filepath.Join(filepath.Dir(thisFile), "..", "..")) + candidate := filepath.Join(moduleRoot, configured) + if _, err := os.Stat(candidate); err == nil { + return candidate + } + + return configured +} diff --git a/migrate/migrate_unit_test.go b/migrate/migrate_unit_test.go new file mode 100644 index 0000000..ca84296 --- /dev/null +++ b/migrate/migrate_unit_test.go @@ -0,0 +1,77 @@ +package migrate + +import ( + "errors" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestResolveMigrationsPath(t *testing.T) { + if got := ResolveMigrationsPath(""); got != "file://migrations" && !strings.HasSuffix(got, "/migrations") { + t.Fatalf("expected default migrations path, got %q", got) + } + + if got := ResolveMigrationsPath("file:///tmp/migs"); got != "file:///tmp/migs" { + t.Fatalf("expected passthrough file URI, got %q", got) + } + + abs := filepath.Join(string(os.PathSeparator), "tmp", "migs") + if got := ResolveMigrationsPath(abs); got != "file://"+abs { + t.Fatalf("expected absolute path conversion, got %q", got) + } + + rel := "migrations" + got := ResolveMigrationsPath(rel) + if !strings.HasPrefix(got, "file://") || !strings.HasSuffix(got, rel) { + t.Fatalf("expected file URI ending with %q, got %q", rel, got) + } +} + +func TestRunWithTimeout(t *testing.T) { + if err := runWithTimeout(0, func() error { return nil }); err != nil { + t.Fatalf("expected no error for immediate execution, got %v", err) + } + + if err := runWithTimeout(50*time.Millisecond, func() error { + time.Sleep(5 * time.Millisecond) + return nil + }); err != nil { + t.Fatalf("expected fn to finish before timeout, got %v", err) + } + + err := runWithTimeout(10*time.Millisecond, func() error { + time.Sleep(50 * time.Millisecond) + return nil + }) + if err == nil || !strings.Contains(err.Error(), "timed out") { + t.Fatalf("expected timeout error, got %v", err) + } + + expected := errors.New("boom") + err = runWithTimeout(20*time.Millisecond, func() error { return expected }) + if !errors.Is(err, expected) { + t.Fatalf("expected function error propagation, got %v", err) + } +} + +func TestMigrationAPIs_InvalidSourcePath(t *testing.T) { + cfg := MigrationConfig{Path: "definitely-missing-migrations-dir"} + + if err := RunMigrationsUp("postgres://ignored", cfg); err == nil { + t.Fatalf("expected RunMigrationsUp error for missing source path") + } + if err := RunMigrationsDown("postgres://ignored", cfg, 1); err == nil { + t.Fatalf("expected RunMigrationsDown error for missing source path") + } + if err := ForceMigrationVersion("postgres://ignored", cfg, 1); err == nil { + t.Fatalf("expected ForceMigrationVersion error for missing source path") + } + if _, err := GetMigrationStatus("postgres://ignored", cfg); err == nil { + t.Fatalf("expected GetMigrationStatus error for missing source path") + } + + closeMigrator(nil) +} diff --git a/smtp/smtp_mailer.go b/smtp/smtp_mailer.go new file mode 100644 index 0000000..f3298d2 --- /dev/null +++ b/smtp/smtp_mailer.go @@ -0,0 +1,345 @@ +package smtp + +import ( + "context" + "crypto/rand" + "crypto/tls" + "encoding/base64" + "encoding/hex" + "fmt" + "net" + "net/smtp" + "strings" + "time" +) + +type SMTPMode string + +const ( + SMTPModeSSL SMTPMode = "ssl" + SMTPModeTLS SMTPMode = "tls" + SMTPModeUnencrypted SMTPMode = "unencrypted" +) + +type SMTPConfig struct { + Host string + Port int + Username string + Password string + From string + Mode SMTPMode + Auth string +} + +type Mailer interface { + Send(ctx context.Context, to string, subject string, htmlBody string, textBody string) error +} + +type SMTPMailer struct { + cfg SMTPConfig +} + +const defaultSMTPOperationTimeout = 30 * time.Second + +func NewSMTPMailer(cfg SMTPConfig) *SMTPMailer { + return &SMTPMailer{cfg: cfg} +} + +func (m *SMTPMailer) Send(ctx context.Context, to string, subject string, htmlBody string, textBody string) error { + if strings.TrimSpace(m.cfg.Host) == "" || m.cfg.Port == 0 || strings.TrimSpace(m.cfg.From) == "" { + return fmt.Errorf("smtp is not configured") + } + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return err + } + + now := time.Now() + addr := fmt.Sprintf("%s:%d", m.cfg.Host, m.cfg.Port) + deadline := smtpSendDeadline(ctx, now) + if !deadline.After(now) { + if err := ctx.Err(); err != nil { + return err + } + return context.DeadlineExceeded + } + + var client *smtp.Client + var err error + + switch m.cfg.Mode { + case SMTPModeSSL: + client, err = dialSSL(ctx, addr, m.cfg.Host, deadline) + case SMTPModeUnencrypted: + client, err = dialPlain(ctx, addr, m.cfg.Host, deadline) + case SMTPModeTLS: + fallthrough + default: + client, err = dialStartTLS(ctx, addr, m.cfg.Host, deadline) + } + if err != nil { + return err + } + defer client.Close() + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + _ = client.Close() + case <-done: + } + }() + defer close(done) + + if m.cfg.Username != "" { + if err := authenticate(client, m.cfg); err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return err + } + } + + if err := client.Mail(m.cfg.From); err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return fmt.Errorf("smtp mail from: %w", err) + } + if err := client.Rcpt(to); err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return fmt.Errorf("smtp rcpt to: %w", err) + } + + writer, err := client.Data() + if err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return fmt.Errorf("smtp data: %w", err) + } + + message := buildMIMEMessage(m.cfg.From, to, subject, textBody, htmlBody) + if _, err := writer.Write([]byte(message)); err != nil { + _ = writer.Close() + if ctx.Err() != nil { + return ctx.Err() + } + return fmt.Errorf("smtp write body: %w", err) + } + if err := writer.Close(); err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return fmt.Errorf("smtp close body: %w", err) + } + if err := client.Quit(); err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return fmt.Errorf("smtp quit: %w", err) + } + return nil +} + +func smtpSendDeadline(ctx context.Context, now time.Time) time.Time { + if ctx != nil { + if deadline, ok := ctx.Deadline(); ok { + return deadline + } + } + return now.Add(defaultSMTPOperationTimeout) +} + +func authenticate(client *smtp.Client, cfg SMTPConfig) error { + authMethod := strings.ToLower(strings.TrimSpace(cfg.Auth)) + if authMethod == "" { + authMethod = "auto" + } + + _, methodsRaw := client.Extension("AUTH") + methods := strings.ToUpper(methodsRaw) + + tryLogin := strings.Contains(methods, "LOGIN") + tryPlain := strings.Contains(methods, "PLAIN") + + switch authMethod { + case "none": + return nil + case "login": + return authWithLogin(client, cfg) + case "plain": + return authWithPlain(client, cfg) + case "auto": + if tryLogin { + if err := authWithLogin(client, cfg); err == nil { + return nil + } + } + if tryPlain { + if err := authWithPlain(client, cfg); err == nil { + return nil + } + } + if !tryLogin && !tryPlain { + // Last fallback if server did not advertise methods. + if err := authWithLogin(client, cfg); err == nil { + return nil + } + if err := authWithPlain(client, cfg); err == nil { + return nil + } + } + return fmt.Errorf("smtp auth failed for available methods") + default: + return fmt.Errorf("unsupported smtp auth method: %s", authMethod) + } +} + +func authWithPlain(client *smtp.Client, cfg SMTPConfig) error { + auth := smtp.PlainAuth("", cfg.Username, cfg.Password, cfg.Host) + if err := client.Auth(auth); err != nil { + return fmt.Errorf("smtp plain auth: %w", err) + } + return nil +} + +func authWithLogin(client *smtp.Client, cfg SMTPConfig) error { + if err := client.Auth(loginAuth{username: cfg.Username, password: cfg.Password}); err != nil { + return fmt.Errorf("smtp login auth: %w", err) + } + return nil +} + +type loginAuth struct { + username string + password string +} + +func (a loginAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) { + return "LOGIN", []byte{}, nil +} + +func (a loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { + if !more { + return nil, nil + } + challengeRaw := strings.TrimSpace(string(fromServer)) + challenge := strings.ToLower(challengeRaw) + if decoded, err := base64.StdEncoding.DecodeString(challengeRaw); err == nil { + challenge = strings.ToLower(string(decoded)) + } + switch { + case strings.Contains(challenge, "username"): + return []byte(a.username), nil + case strings.Contains(challenge, "password"): + return []byte(a.password), nil + } + return nil, fmt.Errorf("unexpected smtp login challenge") +} + +func dialSSL(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) { + timeout := time.Until(deadline) + if timeout <= 0 { + if err := ctx.Err(); err != nil { + return nil, err + } + return nil, context.DeadlineExceeded + } + + dialer := &tls.Dialer{ + NetDialer: &net.Dialer{Timeout: timeout}, + Config: &tls.Config{ServerName: host}, + } + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + if ctx.Err() != nil { + return nil, ctx.Err() + } + return nil, fmt.Errorf("smtp ssl dial: %w", err) + } + _ = conn.SetDeadline(deadline) + client, err := smtp.NewClient(conn, host) + if err != nil { + return nil, fmt.Errorf("smtp new client ssl: %w", err) + } + return client, nil +} + +func dialPlain(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) { + timeout := time.Until(deadline) + if timeout <= 0 { + if err := ctx.Err(); err != nil { + return nil, err + } + return nil, context.DeadlineExceeded + } + + dialer := &net.Dialer{Timeout: timeout} + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + if ctx.Err() != nil { + return nil, ctx.Err() + } + return nil, fmt.Errorf("smtp plain dial: %w", err) + } + _ = conn.SetDeadline(deadline) + client, err := smtp.NewClient(conn, host) + if err != nil { + return nil, fmt.Errorf("smtp new client plain: %w", err) + } + return client, nil +} + +func dialStartTLS(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) { + client, err := dialPlain(ctx, addr, host, deadline) + if err != nil { + return nil, err + } + if ok, _ := client.Extension("STARTTLS"); !ok { + _ = client.Close() + return nil, fmt.Errorf("smtp server does not support STARTTLS") + } + if err := client.StartTLS(&tls.Config{ServerName: host}); err != nil { + _ = client.Close() + if ctx.Err() != nil { + return nil, ctx.Err() + } + return nil, fmt.Errorf("smtp starttls: %w", err) + } + return client, nil +} + +func buildMIMEMessage(from, to, subject, textBody, htmlBody string) string { + boundary := randomMIMEBoundary() + return strings.Join([]string{ + fmt.Sprintf("From: %s", from), + fmt.Sprintf("To: %s", to), + fmt.Sprintf("Subject: %s", subject), + "MIME-Version: 1.0", + fmt.Sprintf("Content-Type: multipart/alternative; boundary=%s", boundary), + "", + fmt.Sprintf("--%s", boundary), + "Content-Type: text/plain; charset=UTF-8", + "", + textBody, + fmt.Sprintf("--%s", boundary), + "Content-Type: text/html; charset=UTF-8", + "", + htmlBody, + fmt.Sprintf("--%s--", boundary), + "", + }, "\r\n") +} + +func randomMIMEBoundary() string { + raw := make([]byte, 12) + if _, err := rand.Read(raw); err != nil { + return fmt.Sprintf("mime-boundary-%d", time.Now().UnixNano()) + } + return "mime-boundary-" + hex.EncodeToString(raw) +} diff --git a/smtp/smtp_mailer_additional_test.go b/smtp/smtp_mailer_additional_test.go new file mode 100644 index 0000000..15c472d --- /dev/null +++ b/smtp/smtp_mailer_additional_test.go @@ -0,0 +1,126 @@ +package smtp + +import ( + "context" + "strings" + "testing" + "time" +) + +func TestSendReturnsConfigurationErrorWhenSMTPMissing(t *testing.T) { + mailer := NewSMTPMailer(SMTPConfig{Host: "", Port: 587, From: "noreply@example.com", Mode: SMTPModeTLS}) + + err := mailer.Send(context.Background(), "to@example.com", "subject", "

body

", "body") + if err == nil { + t.Fatal("expected configuration error") + } + if !strings.Contains(err.Error(), "smtp is not configured") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSendReturnsDeadlineExceededForExpiredContext(t *testing.T) { + mailer := NewSMTPMailer(SMTPConfig{Host: "smtp.example.com", Port: 587, From: "noreply@example.com", Mode: SMTPModeTLS}) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer cancel() + + err := mailer.Send(ctx, "to@example.com", "subject", "

body

", "body") + if err == nil { + t.Fatal("expected context deadline exceeded") + } + if err != context.DeadlineExceeded { + t.Fatalf("expected context deadline exceeded, got %v", err) + } +} + +func TestLoginAuthNextHandlesChallenges(t *testing.T) { + auth := loginAuth{username: "alice", password: "s3cret"} + + proto, initial, err := auth.Start(nil) + if err != nil { + t.Fatalf("unexpected start error: %v", err) + } + if proto != "LOGIN" { + t.Fatalf("expected LOGIN auth proto, got %q", proto) + } + if len(initial) != 0 { + t.Fatalf("expected empty initial response, got %q", string(initial)) + } + + value, err := auth.Next([]byte("Username:"), true) + if err != nil { + t.Fatalf("unexpected username challenge error: %v", err) + } + if string(value) != "alice" { + t.Fatalf("expected username response alice, got %q", string(value)) + } + + value, err = auth.Next([]byte("UGFzc3dvcmQ6"), true) + if err != nil { + t.Fatalf("unexpected password challenge error: %v", err) + } + if string(value) != "s3cret" { + t.Fatalf("expected password response s3cret, got %q", string(value)) + } +} + +func TestLoginAuthNextHandlesTerminalAndUnexpectedChallenge(t *testing.T) { + auth := loginAuth{username: "alice", password: "s3cret"} + + value, err := auth.Next([]byte("ignored"), false) + if err != nil { + t.Fatalf("expected nil error when more=false, got %v", err) + } + if value != nil { + t.Fatalf("expected nil value when more=false, got %q", string(value)) + } + + if _, err := auth.Next([]byte("realm"), true); err == nil { + t.Fatal("expected error for unexpected login challenge") + } +} + +func TestBuildMIMEMessageContainsMultipartSections(t *testing.T) { + msg := buildMIMEMessage("from@example.com", "to@example.com", "Subject", "plain body", "

html body

") + + checks := []string{ + "From: from@example.com", + "To: to@example.com", + "Subject: Subject", + "MIME-Version: 1.0", + "Content-Type: text/plain; charset=UTF-8", + "plain body", + "Content-Type: text/html; charset=UTF-8", + "

html body

", + } + for _, snippet := range checks { + if !strings.Contains(msg, snippet) { + t.Fatalf("expected MIME message to contain %q", snippet) + } + } + if !strings.Contains(msg, "\r\n") { + t.Fatal("expected CRLF separators in MIME message") + } + + if !strings.Contains(msg, "Content-Type: multipart/alternative; boundary=mime-boundary-") { + t.Fatalf("expected random mime boundary header, got %q", msg) + } + if !strings.Contains(msg, "--mime-boundary-") { + t.Fatalf("expected mime boundary delimiters in message, got %q", msg) + } +} + +func TestDialHelpersReturnDeadlineExceededWhenDeadlineHasPassed(t *testing.T) { + deadline := time.Now().Add(-time.Second) + ctx := context.Background() + + if _, err := dialPlain(ctx, "smtp.example.com:25", "smtp.example.com", deadline); err != context.DeadlineExceeded { + t.Fatalf("dialPlain expected context deadline exceeded, got %v", err) + } + if _, err := dialSSL(ctx, "smtp.example.com:465", "smtp.example.com", deadline); err != context.DeadlineExceeded { + t.Fatalf("dialSSL expected context deadline exceeded, got %v", err) + } + if _, err := dialStartTLS(ctx, "smtp.example.com:587", "smtp.example.com", deadline); err != context.DeadlineExceeded { + t.Fatalf("dialStartTLS expected context deadline exceeded, got %v", err) + } +} diff --git a/smtp/smtp_mailer_integration_test.go b/smtp/smtp_mailer_integration_test.go new file mode 100644 index 0000000..62d2bcf --- /dev/null +++ b/smtp/smtp_mailer_integration_test.go @@ -0,0 +1,424 @@ +package smtp + +import ( + "bufio" + "context" + "net" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +type smtpTestServerConfig struct { + extensions []string + failAuth bool + failMail bool + failRcpt bool + failData bool + failQuit bool +} + +type smtpTestServerState struct { + authCalls int + mailFrom string + rcptTo string + data string +} + +type smtpTestServer struct { + host string + port int + state *smtpTestServerState + stop func() +} + +func requireSMTPIntegration(t *testing.T) { + t.Helper() + if testing.Short() { + t.Skip("skipping smtp integration test in short mode") + } +} + +func startSMTPTestServer(t *testing.T, cfg smtpTestServerConfig) *smtpTestServer { + t.Helper() + requireSMTPIntegration(t) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen smtp test server: %v", err) + } + + state := &smtpTestServerState{} + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + + r := bufio.NewReader(conn) + w := bufio.NewWriter(conn) + write := func(line string) bool { + if _, err := w.WriteString(line + "\r\n"); err != nil { + return false + } + return w.Flush() == nil + } + if !write("220 localhost ESMTP ready") { + return + } + + inData := false + var dataBuf strings.Builder + for { + line, err := r.ReadString('\n') + if err != nil { + return + } + line = strings.TrimRight(line, "\r\n") + + if inData { + if line == "." { + state.data = dataBuf.String() + if !write("250 2.0.0 queued") { + return + } + inData = false + continue + } + dataBuf.WriteString(line) + dataBuf.WriteString("\n") + continue + } + + parts := strings.SplitN(line, " ", 2) + cmd := strings.ToUpper(parts[0]) + arg := "" + if len(parts) > 1 { + arg = parts[1] + } + + switch cmd { + case "EHLO", "HELO": + lines := append([]string{"localhost"}, cfg.extensions...) + for i, ext := range lines { + prefix := "250-" + if i == len(lines)-1 { + prefix = "250 " + } + if !write(prefix + ext) { + return + } + } + case "AUTH": + state.authCalls++ + if cfg.failAuth { + if !write("535 5.7.8 auth failed") { + return + } + continue + } + authArg := strings.ToUpper(strings.TrimSpace(arg)) + switch { + case strings.HasPrefix(authArg, "PLAIN"): + if !write("235 2.7.0 authenticated") { + return + } + case strings.HasPrefix(authArg, "LOGIN"): + if !write("334 VXNlcm5hbWU6") { + return + } + if _, err := r.ReadString('\n'); err != nil { + return + } + if !write("334 UGFzc3dvcmQ6") { + return + } + if _, err := r.ReadString('\n'); err != nil { + return + } + if !write("235 2.7.0 authenticated") { + return + } + default: + if !write("504 5.5.4 unsupported auth mechanism") { + return + } + } + case "MAIL": + state.mailFrom = arg + if cfg.failMail { + if !write("550 5.1.0 sender rejected") { + return + } + continue + } + if !write("250 2.1.0 ok") { + return + } + case "RCPT": + state.rcptTo = arg + if cfg.failRcpt { + if !write("550 5.1.1 recipient rejected") { + return + } + continue + } + if !write("250 2.1.5 ok") { + return + } + case "DATA": + if cfg.failData { + if !write("554 5.5.0 data rejected") { + return + } + continue + } + if !write("354 end data with .") { + return + } + inData = true + dataBuf.Reset() + case "QUIT": + if cfg.failQuit { + _ = write("554 5.5.1 quit failed") + return + } + _ = write("221 2.0.0 bye") + return + default: + if !write("250 2.0.0 ok") { + return + } + } + } + }() + + tcpAddr := ln.Addr().(*net.TCPAddr) + return &smtpTestServer{ + host: "127.0.0.1", + port: tcpAddr.Port, + state: state, + stop: func() { + _ = ln.Close() + wg.Wait() + }, + } +} + +func TestSendUnencryptedSuccess(t *testing.T) { + srv := startSMTPTestServer(t, smtpTestServerConfig{}) + defer srv.stop() + + mailer := NewSMTPMailer(SMTPConfig{ + Host: srv.host, + Port: srv.port, + From: "noreply@example.com", + Mode: SMTPModeUnencrypted, + }) + + err := mailer.Send(context.Background(), "to@example.com", "subject", "

hello

", "hello") + if err != nil { + t.Fatalf("Send: %v", err) + } + if !strings.Contains(srv.state.mailFrom, "noreply@example.com") { + t.Fatalf("expected mail from to include sender, got %q", srv.state.mailFrom) + } + if !strings.Contains(srv.state.rcptTo, "to@example.com") { + t.Fatalf("expected rcpt to include recipient, got %q", srv.state.rcptTo) + } + if !strings.Contains(srv.state.data, "Subject: subject") { + t.Fatalf("expected message data to include subject, got %q", srv.state.data) + } +} + +func TestSendUnencryptedErrorPaths(t *testing.T) { + tests := []struct { + name string + cfg smtpTestServerConfig + wantErr string + }{ + {name: "mail from error", cfg: smtpTestServerConfig{failMail: true}, wantErr: "smtp mail from"}, + {name: "rcpt error", cfg: smtpTestServerConfig{failRcpt: true}, wantErr: "smtp rcpt to"}, + {name: "data error", cfg: smtpTestServerConfig{failData: true}, wantErr: "smtp data"}, + {name: "quit error", cfg: smtpTestServerConfig{failQuit: true}, wantErr: "smtp quit"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + srv := startSMTPTestServer(t, tc.cfg) + defer srv.stop() + + mailer := NewSMTPMailer(SMTPConfig{ + Host: srv.host, + Port: srv.port, + From: "noreply@example.com", + Mode: SMTPModeUnencrypted, + }) + err := mailer.Send(context.Background(), "to@example.com", "subject", "

hello

", "hello") + if err == nil || !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected %q error, got %v", tc.wantErr, err) + } + }) + } +} + +func TestSendAuthModes(t *testing.T) { + t.Run("auth none does not issue AUTH command", func(t *testing.T) { + srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH PLAIN LOGIN"}}) + defer srv.stop() + + mailer := NewSMTPMailer(SMTPConfig{ + Host: srv.host, + Port: srv.port, + From: "noreply@example.com", + Mode: SMTPModeUnencrypted, + Username: "alice", + Password: "secret", + Auth: "none", + }) + + if err := mailer.Send(context.Background(), "to@example.com", "subject", "

hello

", "hello"); err != nil { + t.Fatalf("Send: %v", err) + } + if srv.state.authCalls != 0 { + t.Fatalf("expected no AUTH command, got %d", srv.state.authCalls) + } + }) + + t.Run("auth plain", func(t *testing.T) { + srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH PLAIN"}}) + defer srv.stop() + + mailer := NewSMTPMailer(SMTPConfig{ + Host: srv.host, + Port: srv.port, + From: "noreply@example.com", + Mode: SMTPModeUnencrypted, + Username: "alice", + Password: "secret", + Auth: "plain", + }) + + if err := mailer.Send(context.Background(), "to@example.com", "subject", "

hello

", "hello"); err != nil { + t.Fatalf("Send: %v", err) + } + if srv.state.authCalls == 0 { + t.Fatal("expected AUTH command for plain auth") + } + }) + + t.Run("auth login", func(t *testing.T) { + srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH LOGIN"}}) + defer srv.stop() + + mailer := NewSMTPMailer(SMTPConfig{ + Host: srv.host, + Port: srv.port, + From: "noreply@example.com", + Mode: SMTPModeUnencrypted, + Username: "alice", + Password: "secret", + Auth: "login", + }) + + if err := mailer.Send(context.Background(), "to@example.com", "subject", "

hello

", "hello"); err != nil { + t.Fatalf("Send: %v", err) + } + if srv.state.authCalls == 0 { + t.Fatal("expected AUTH command for login auth") + } + }) + + t.Run("auth auto failure", func(t *testing.T) { + srv := startSMTPTestServer(t, smtpTestServerConfig{failAuth: true}) + defer srv.stop() + + mailer := NewSMTPMailer(SMTPConfig{ + Host: srv.host, + Port: srv.port, + From: "noreply@example.com", + Mode: SMTPModeUnencrypted, + Username: "alice", + Password: "secret", + Auth: "auto", + }) + + err := mailer.Send(context.Background(), "to@example.com", "subject", "

hello

", "hello") + if err == nil || !strings.Contains(err.Error(), "smtp auth failed") { + t.Fatalf("expected smtp auth failed error, got %v", err) + } + }) + + t.Run("unsupported auth method", func(t *testing.T) { + srv := startSMTPTestServer(t, smtpTestServerConfig{}) + defer srv.stop() + + mailer := NewSMTPMailer(SMTPConfig{ + Host: srv.host, + Port: srv.port, + From: "noreply@example.com", + Mode: SMTPModeUnencrypted, + Username: "alice", + Password: "secret", + Auth: "weird", + }) + + err := mailer.Send(context.Background(), "to@example.com", "subject", "

hello

", "hello") + if err == nil || !strings.Contains(err.Error(), "unsupported smtp auth method") { + t.Fatalf("expected unsupported auth method error, got %v", err) + } + }) +} + +func TestSendTLSFailsWhenStartTLSExtensionMissing(t *testing.T) { + srv := startSMTPTestServer(t, smtpTestServerConfig{}) + defer srv.stop() + + mailer := NewSMTPMailer(SMTPConfig{ + Host: srv.host, + Port: srv.port, + From: "noreply@example.com", + Mode: SMTPModeTLS, + }) + + err := mailer.Send(context.Background(), "to@example.com", "subject", "

hello

", "hello") + if err == nil || !strings.Contains(err.Error(), "does not support STARTTLS") { + t.Fatalf("expected STARTTLS unsupported error, got %v", err) + } +} + +func TestDialHelpersWrapDialErrors(t *testing.T) { + requireSMTPIntegration(t) + + ctx := context.Background() + deadline := time.Now().Add(500 * time.Millisecond) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen ephemeral: %v", err) + } + addr := ln.Addr().String() + _ = ln.Close() + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + t.Fatalf("split host port: %v", err) + } + port, err := strconv.Atoi(portStr) + if err != nil { + t.Fatalf("parse port: %v", err) + } + unusedAddr := net.JoinHostPort(host, strconv.Itoa(port)) + + if _, err := dialPlain(ctx, unusedAddr, host, deadline); err == nil || !strings.Contains(err.Error(), "smtp plain dial") { + t.Fatalf("expected wrapped plain dial error, got %v", err) + } + if _, err := dialSSL(ctx, unusedAddr, host, deadline); err == nil || !strings.Contains(err.Error(), "smtp ssl dial") { + t.Fatalf("expected wrapped ssl dial error, got %v", err) + } +} diff --git a/smtp/smtp_mailer_test.go b/smtp/smtp_mailer_test.go new file mode 100644 index 0000000..796ccb4 --- /dev/null +++ b/smtp/smtp_mailer_test.go @@ -0,0 +1,47 @@ +package smtp + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestSendHonorsCanceledContext(t *testing.T) { + mailer := NewSMTPMailer(SMTPConfig{ + Host: "smtp.example.com", + Port: 587, + From: "noreply@example.com", + Mode: SMTPModeTLS, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := mailer.Send(ctx, "to@example.com", "subject", "

hello

", "hello") + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected %v, got %v", context.Canceled, err) + } +} + +func TestSMTPSendDeadlineUsesContextDeadline(t *testing.T) { + now := time.Date(2026, time.January, 1, 12, 0, 0, 0, time.UTC) + want := now.Add(2 * time.Minute) + ctx, cancel := context.WithDeadline(context.Background(), want) + defer cancel() + + got := smtpSendDeadline(ctx, now) + if !got.Equal(want) { + t.Fatalf("expected context deadline %v, got %v", want, got) + } +} + +func TestSMTPSendDeadlineUsesDefaultWhenContextHasNoDeadline(t *testing.T) { + now := time.Date(2026, time.January, 1, 12, 0, 0, 0, time.UTC) + want := now.Add(defaultSMTPOperationTimeout) + + got := smtpSendDeadline(context.Background(), now) + if !got.Equal(want) { + t.Fatalf("expected default deadline %v, got %v", want, got) + } +} diff --git a/worker/poller.go b/worker/poller.go new file mode 100644 index 0000000..81731b6 --- /dev/null +++ b/worker/poller.go @@ -0,0 +1,70 @@ +package worker + +import ( + "context" + "log/slog" + "time" +) + +const defaultTaskName = "batch worker" + +// BatchRunner executes one batch and returns the number of processed records. +// Returning 0 stops the current drain cycle. +type BatchRunner func(ctx context.Context, limit int) (int, error) + +func Run(ctx context.Context, runner BatchRunner, interval time.Duration, batchSize int, logger *slog.Logger, taskName string) { + if runner == nil || interval <= 0 || batchSize <= 0 { + return + } + if logger == nil { + logger = slog.Default() + } + if taskName == "" { + taskName = defaultTaskName + } + + RunOnce(ctx, runner, batchSize, logger, taskName) + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + RunOnce(ctx, runner, batchSize, logger, taskName) + } + } +} + +func RunOnce(ctx context.Context, runner BatchRunner, batchSize int, logger *slog.Logger, taskName string) { + if runner == nil || batchSize <= 0 { + return + } + if logger == nil { + logger = slog.Default() + } + if taskName == "" { + taskName = defaultTaskName + } + + totalProcessed := 0 + for { + processed, err := runner(ctx, batchSize) + if err != nil { + if ctx.Err() == nil { + logger.Warn(taskName+" run failed", "error", err) + } + return + } + if processed <= 0 { + break + } + totalProcessed += processed + } + + if totalProcessed > 0 { + logger.Info(taskName+" completed", "processed", totalProcessed) + } +} diff --git a/worker/poller_test.go b/worker/poller_test.go new file mode 100644 index 0000000..e39a5f0 --- /dev/null +++ b/worker/poller_test.go @@ -0,0 +1,125 @@ +package worker + +import ( + "context" + "errors" + "io" + "log/slog" + "sync" + "testing" + "time" +) + +type fakeScheduledPostPromoter struct { + mu sync.Mutex + calls int + limits []int + results []int + errAt int + callHook func(int) +} + +func (f *fakeScheduledPostPromoter) PromoteDueScheduled(_ context.Context, limit int) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + + f.calls++ + callNo := f.calls + f.limits = append(f.limits, limit) + if f.callHook != nil { + f.callHook(callNo) + } + + if f.errAt > 0 && callNo == f.errAt { + return 0, errors.New("boom") + } + + if len(f.results) == 0 { + return 0, nil + } + + result := f.results[0] + f.results = f.results[1:] + return result, nil +} + +func (f *fakeScheduledPostPromoter) Calls() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.calls +} + +func (f *fakeScheduledPostPromoter) Limits() []int { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]int, len(f.limits)) + copy(out, f.limits) + return out +} + +func newDiscardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func TestPromoteDueScheduledOnceBatchesUntilEmpty(t *testing.T) { + repo := &fakeScheduledPostPromoter{results: []int{2, 3, 0}} + + RunOnce(context.Background(), repo.PromoteDueScheduled, 50, newDiscardLogger(), "scheduled post promotion") + + if got := repo.Calls(); got != 3 { + t.Fatalf("expected 3 calls, got %d", got) + } + limits := repo.Limits() + if len(limits) != 3 || limits[0] != 50 || limits[1] != 50 || limits[2] != 50 { + t.Fatalf("unexpected limits: %+v", limits) + } +} + +func TestPromoteDueScheduledOnceStopsOnError(t *testing.T) { + repo := &fakeScheduledPostPromoter{ + results: []int{4, 4}, + errAt: 2, + } + + RunOnce(context.Background(), repo.PromoteDueScheduled, 25, newDiscardLogger(), "scheduled post promotion") + + if got := repo.Calls(); got != 2 { + t.Fatalf("expected 2 calls before stop, got %d", got) + } +} + +func TestRunScheduledPostPromoterStopsOnContextCancel(t *testing.T) { + callCh := make(chan int, 8) + repo := &fakeScheduledPostPromoter{ + results: []int{0, 0, 0, 0}, + callHook: func(call int) { + callCh <- call + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + Run(ctx, repo.PromoteDueScheduled, 10*time.Millisecond, 10, newDiscardLogger(), "scheduled post promotion") + close(done) + }() + + timeout := time.After(300 * time.Millisecond) + for { + select { + case <-callCh: + if repo.Calls() >= 2 { + cancel() + select { + case <-done: + return + case <-time.After(200 * time.Millisecond): + t.Fatal("promoter did not stop after context cancel") + } + } + case <-timeout: + cancel() + t.Fatal("timed out waiting for promoter calls") + } + } +}