add core lib

This commit is contained in:
2026-03-01 03:04:10 +01:00
parent 9a6818ea3c
commit baa764befd
22 changed files with 2353 additions and 0 deletions

62
.drone.yml Normal file
View File

@@ -0,0 +1,62 @@
---
kind: pipeline
type: docker
name: go-core-ci
trigger:
branch:
- main
- develop
event:
- push
- pull_request
environment:
GOPROXY: https://nexus.beatrice.wtf/repository/go-group/,direct
GOPRIVATE: git.beatrice.wtf/panic.haus/*
GONOSUMDB: git.beatrice.wtf/panic.haus/*
steps:
- name: test
image: golang:1.26
commands:
- go mod download
- go test ./... -count=1
- name: build
image: golang:1.26
commands:
- go build ./...
---
kind: pipeline
type: docker
name: go-core-release
trigger:
event:
- tag
ref:
- refs/tags/v*
environment:
GOPROXY: https://nexus.beatrice.wtf/repository/go-group/,direct
GOPRIVATE: git.beatrice.wtf/panic.haus/*
GONOSUMDB: git.beatrice.wtf/panic.haus/*
steps:
- name: test
image: golang:1.26
commands:
- go mod download
- go test ./... -count=1
- name: build
image: golang:1.26
commands:
- go build ./...
- name: warm-proxy
image: golang:1.26
commands:
- go list -m git.beatrice.wtf/panic.haus/go-core@${DRONE_TAG}

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
coverage.out
bin/

12
README.md Normal file
View File

@@ -0,0 +1,12 @@
# go-core
Reusable backend infrastructure module.
## Packages
- `dotenv`: load dotenv candidates while preserving existing env vars.
- `dbpool`: PostgreSQL pool bootstrap for pgx.
- `migrate`: migration helpers around `golang-migrate`.
- `smtp`: SMTP mailer implementation.
- `middleware`: reusable HTTP middleware (`CORS`, IP rate limiter).
- `worker`: generic batch poller/runner utilities.

68
dbpool/postgres.go Normal file
View File

@@ -0,0 +1,68 @@
package dbpool
import (
"context"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
type PoolConfig struct {
MaxOpenConns int
MinIdleConns int
MaxConnLifetime time.Duration
MaxConnIdleTime time.Duration
HealthCheckPeriod time.Duration
ConnectionAcquireWait time.Duration
}
func NewPostgresPool(ctx context.Context, databaseURL string, pool PoolConfig) (*pgxpool.Pool, error) {
cfg, err := pgxpool.ParseConfig(databaseURL)
if err != nil {
return nil, fmt.Errorf("parse postgres pool config: %w", err)
}
if pool.MaxOpenConns <= 0 {
pool.MaxOpenConns = 25
}
if pool.MinIdleConns < 0 {
pool.MinIdleConns = 0
}
if pool.MinIdleConns > pool.MaxOpenConns {
pool.MinIdleConns = pool.MaxOpenConns
}
if pool.MaxConnLifetime <= 0 {
pool.MaxConnLifetime = 30 * time.Minute
}
if pool.MaxConnIdleTime <= 0 {
pool.MaxConnIdleTime = 5 * time.Minute
}
if pool.HealthCheckPeriod <= 0 {
pool.HealthCheckPeriod = time.Minute
}
if pool.ConnectionAcquireWait <= 0 {
pool.ConnectionAcquireWait = 10 * time.Second
}
cfg.MaxConns = int32(pool.MaxOpenConns)
cfg.MinConns = int32(pool.MinIdleConns)
cfg.MaxConnLifetime = pool.MaxConnLifetime
cfg.MaxConnIdleTime = pool.MaxConnIdleTime
cfg.HealthCheckPeriod = pool.HealthCheckPeriod
pingCtx, cancel := context.WithTimeout(ctx, pool.ConnectionAcquireWait)
defer cancel()
poolConn, err := pgxpool.NewWithConfig(ctx, cfg)
if err != nil {
return nil, fmt.Errorf("open postgres pool: %w", err)
}
if err := poolConn.Ping(pingCtx); err != nil {
poolConn.Close()
return nil, fmt.Errorf("ping postgres: %w", err)
}
return poolConn, nil
}

26
dbpool/postgres_test.go Normal file
View File

@@ -0,0 +1,26 @@
package dbpool
import (
"context"
"strings"
"testing"
"time"
)
func TestNewPostgresPool_ParseConfigError(t *testing.T) {
_, err := NewPostgresPool(context.Background(), "://invalid-url", PoolConfig{})
if err == nil || !strings.Contains(err.Error(), "parse postgres pool config") {
t.Fatalf("expected parse config error, got %v", err)
}
}
func TestNewPostgresPool_PingError(t *testing.T) {
_, err := NewPostgresPool(
context.Background(),
"postgres://postgres:postgres@127.0.0.1:1/appdb?sslmode=disable",
PoolConfig{ConnectionAcquireWait: 20 * time.Millisecond},
)
if err == nil || !strings.Contains(err.Error(), "ping postgres") {
t.Fatalf("expected ping error, got %v", err)
}
}

52
dotenv/dotenv.go Normal file
View File

@@ -0,0 +1,52 @@
package config
import (
"bufio"
"os"
"strings"
)
// LoadDotEnvCandidates loads KEY=VALUE pairs from the first existing file in paths.
// Existing process env vars are preserved (file values only fill missing keys).
func LoadDotEnvCandidates(paths []string) {
for _, path := range paths {
if loadDotEnvFile(path) {
return
}
}
}
func loadDotEnvFile(path string) bool {
file, err := os.Open(path)
if err != nil {
return false
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
key, value, ok := strings.Cut(line, "=")
if !ok {
continue
}
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
value = strings.Trim(value, `"'`)
if key == "" {
continue
}
if _, exists := os.LookupEnv(key); exists {
continue
}
_ = os.Setenv(key, value)
}
return true
}

76
dotenv/dotenv_test.go Normal file
View File

@@ -0,0 +1,76 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadDotEnvFileParsesAndPreservesExistingValues(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, ".env")
content := "" +
"# comment\n" +
"KEY1 = value1\n" +
"KEY2='quoted value'\n" +
"KEY3=\"double quoted\"\n" +
"INVALID_LINE\n" +
"=missing_key\n"
if err := os.WriteFile(path, []byte(content), 0o600); err != nil {
t.Fatalf("write env file: %v", err)
}
t.Setenv("KEY1", "existing")
_ = os.Unsetenv("KEY2")
_ = os.Unsetenv("KEY3")
if ok := loadDotEnvFile(path); !ok {
t.Fatal("expected loadDotEnvFile to return true")
}
if got := os.Getenv("KEY1"); got != "existing" {
t.Fatalf("expected KEY1 to preserve existing value, got %q", got)
}
if got := os.Getenv("KEY2"); got != "quoted value" {
t.Fatalf("expected KEY2 parsed value, got %q", got)
}
if got := os.Getenv("KEY3"); got != "double quoted" {
t.Fatalf("expected KEY3 parsed value, got %q", got)
}
}
func TestLoadDotEnvCandidatesStopsAtFirstExistingFile(t *testing.T) {
dir := t.TempDir()
first := filepath.Join(dir, "first.env")
second := filepath.Join(dir, "second.env")
if err := os.WriteFile(first, []byte("KEY_A=first\nKEY_B=from_first\n"), 0o600); err != nil {
t.Fatalf("write first env file: %v", err)
}
if err := os.WriteFile(second, []byte("KEY_A=second\nKEY_C=from_second\n"), 0o600); err != nil {
t.Fatalf("write second env file: %v", err)
}
_ = os.Unsetenv("KEY_A")
_ = os.Unsetenv("KEY_B")
_ = os.Unsetenv("KEY_C")
LoadDotEnvCandidates([]string{filepath.Join(dir, "missing.env"), first, second})
if got := os.Getenv("KEY_A"); got != "first" {
t.Fatalf("expected KEY_A from first file, got %q", got)
}
if got := os.Getenv("KEY_B"); got != "from_first" {
t.Fatalf("expected KEY_B from first file, got %q", got)
}
if got := os.Getenv("KEY_C"); got != "" {
t.Fatalf("expected KEY_C to remain unset, got %q", got)
}
}
func TestLoadDotEnvFileReturnsFalseWhenMissing(t *testing.T) {
if ok := loadDotEnvFile(filepath.Join(t.TempDir(), "does-not-exist.env")); ok {
t.Fatal("expected loadDotEnvFile to return false for missing file")
}
}

17
go.mod Normal file
View File

@@ -0,0 +1,17 @@
module git.beatrice.wtf/panic.haus/go-core
go 1.25
require (
github.com/golang-migrate/migrate/v4 v4.19.1
github.com/jackc/pgx/v5 v5.8.0
)
require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/lib/pq v1.10.9 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/text v0.31.0 // indirect
)

93
go.sum Normal file
View File

@@ -0,0 +1,93 @@
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dhui/dktest v0.4.6 h1:+DPKyScKSEp3VLtbMDHcUq6V5Lm5zfZZVb0Sk7Ahom4=
github.com/dhui/dktest v0.4.6/go.mod h1:JHTSYDtKkvFNFHJKqCzVzqXecyv+tKt8EzceOmQOgbU=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.3.3+incompatible h1:Dypm25kh4rmk49v1eiVbsAtpAsYURjYkaKubwuBdxEI=
github.com/docker/docker v28.3.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q=
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

67
middleware/cors.go Normal file
View File

@@ -0,0 +1,67 @@
package middleware
import (
"net"
"net/http"
"net/url"
"strings"
)
type CORSMiddleware struct {
allowedOrigins map[string]struct{}
allowAll bool
}
func NewCORSMiddleware(origins []string) *CORSMiddleware {
m := &CORSMiddleware{allowedOrigins: map[string]struct{}{}}
for _, origin := range origins {
if origin == "*" {
m.allowAll = true
continue
}
m.allowedOrigins[origin] = struct{}{}
}
return m
}
func (m *CORSMiddleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin != "" && (m.allowAll || m.isAllowed(origin) || isSameHost(origin, r.Host)) {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Vary", "Origin")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type,Authorization")
}
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
}
func (m *CORSMiddleware) isAllowed(origin string) bool {
_, ok := m.allowedOrigins[origin]
return ok
}
func isSameHost(origin string, requestHost string) bool {
parsed, err := url.Parse(origin)
if err != nil || parsed.Host == "" {
return false
}
originHost := parsed.Hostname()
reqHost := requestHost
if strings.Contains(requestHost, ":") {
if host, _, splitErr := net.SplitHostPort(requestHost); splitErr == nil {
reqHost = host
}
}
return strings.EqualFold(originHost, reqHost)
}

137
middleware/cors_test.go Normal file
View File

@@ -0,0 +1,137 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestCORSHandlerAllowsConfiguredOrigin(t *testing.T) {
mw := NewCORSMiddleware([]string{"https://allowed.example"})
nextCalled := false
h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "https://allowed.example")
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rr.Code)
}
if rr.Header().Get("Access-Control-Allow-Origin") != "https://allowed.example" {
t.Fatalf("unexpected allow-origin header: %q", rr.Header().Get("Access-Control-Allow-Origin"))
}
if rr.Header().Get("Access-Control-Allow-Credentials") != "true" {
t.Fatalf("expected allow-credentials true, got %q", rr.Header().Get("Access-Control-Allow-Credentials"))
}
if !nextCalled {
t.Fatal("expected next handler to be called")
}
}
func TestCORSHandlerAllowsAnyOriginWhenWildcardConfigured(t *testing.T) {
mw := NewCORSMiddleware([]string{"*"})
h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "https://random.example")
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
if rr.Header().Get("Access-Control-Allow-Origin") != "https://random.example" {
t.Fatalf("expected wildcard middleware to echo origin, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
}
if rr.Header().Get("Access-Control-Allow-Credentials") != "true" {
t.Fatalf("expected allow-credentials true, got %q", rr.Header().Get("Access-Control-Allow-Credentials"))
}
}
func TestCORSHandlerAllowsSameHostOrigin(t *testing.T) {
mw := NewCORSMiddleware(nil)
h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://api.example:8080/path", nil)
req.Host = "api.example:8080"
req.Header.Set("Origin", "https://api.example")
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
if rr.Header().Get("Access-Control-Allow-Origin") != "https://api.example" {
t.Fatalf("expected same-host origin to be allowed, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
}
if rr.Header().Get("Access-Control-Allow-Credentials") != "true" {
t.Fatalf("expected allow-credentials true, got %q", rr.Header().Get("Access-Control-Allow-Credentials"))
}
}
func TestCORSHandlerDoesNotSetHeadersForDisallowedOrigin(t *testing.T) {
mw := NewCORSMiddleware([]string{"https://allowed.example"})
h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "https://blocked.example")
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
if rr.Header().Get("Access-Control-Allow-Origin") != "" {
t.Fatalf("expected no allow-origin header, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
}
}
func TestCORSHandlerOptionsShortCircuits(t *testing.T) {
mw := NewCORSMiddleware([]string{"https://allowed.example"})
nextCalled := false
h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodOptions, "/", nil)
req.Header.Set("Origin", "https://allowed.example")
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
if rr.Code != http.StatusNoContent {
t.Fatalf("expected 204, got %d", rr.Code)
}
if rr.Header().Get("Access-Control-Allow-Credentials") != "true" {
t.Fatalf("expected allow-credentials true, got %q", rr.Header().Get("Access-Control-Allow-Credentials"))
}
if nextCalled {
t.Fatal("expected next handler not to be called for OPTIONS")
}
}
func TestIsSameHost(t *testing.T) {
tests := []struct {
name string
origin string
requestHost string
want bool
}{
{name: "same host exact", origin: "https://api.example", requestHost: "api.example", want: true},
{name: "same host with port", origin: "https://api.example", requestHost: "api.example:8080", want: true},
{name: "case insensitive", origin: "https://API.EXAMPLE", requestHost: "api.example", want: true},
{name: "different host", origin: "https://api.example", requestHost: "other.example", want: false},
{name: "invalid origin", origin: "://bad", requestHost: "api.example", want: false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := isSameHost(tc.origin, tc.requestHost)
if got != tc.want {
t.Fatalf("isSameHost(%q, %q) = %v, want %v", tc.origin, tc.requestHost, got, tc.want)
}
})
}
}

133
middleware/rate_limit.go Normal file
View File

@@ -0,0 +1,133 @@
package middleware
import (
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
)
type ipWindowCounter struct {
windowStart time.Time
lastSeen time.Time
count int
}
type IPRateLimiter struct {
mu sync.Mutex
limit int
window time.Duration
ttl time.Duration
maxEntries int
lastCleanup time.Time
entries map[string]*ipWindowCounter
}
const defaultRateLimiterMaxEntries = 10000
func NewIPRateLimiter(limit int, window time.Duration, ttl time.Duration) *IPRateLimiter {
if limit <= 0 {
limit = 60
}
if window <= 0 {
window = time.Minute
}
if ttl <= 0 {
ttl = 10 * time.Minute
}
return &IPRateLimiter{
limit: limit,
window: window,
ttl: ttl,
maxEntries: defaultRateLimiterMaxEntries,
entries: map[string]*ipWindowCounter{},
}
}
func (m *IPRateLimiter) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !m.allow(r) {
w.Header().Set("Retry-After", strconv.Itoa(int(m.window.Seconds())))
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
func (m *IPRateLimiter) allow(r *http.Request) bool {
key := clientKey(r)
now := time.Now()
m.mu.Lock()
defer m.mu.Unlock()
if m.lastCleanup.IsZero() || now.Sub(m.lastCleanup) >= m.window {
m.cleanupLocked(now)
m.lastCleanup = now
}
entry, exists := m.entries[key]
if !exists {
if len(m.entries) >= m.maxEntries {
m.evictOldestLocked()
}
m.entries[key] = &ipWindowCounter{
windowStart: now,
lastSeen: now,
count: 1,
}
return true
}
if now.Sub(entry.windowStart) >= m.window {
entry.windowStart = now
entry.lastSeen = now
entry.count = 1
return true
}
entry.lastSeen = now
if entry.count >= m.limit {
return false
}
entry.count++
return true
}
func (m *IPRateLimiter) cleanupLocked(now time.Time) {
for key, entry := range m.entries {
if now.Sub(entry.lastSeen) > m.ttl {
delete(m.entries, key)
}
}
}
func (m *IPRateLimiter) evictOldestLocked() {
var oldestKey string
var oldestTime time.Time
first := true
for key, entry := range m.entries {
if first || entry.lastSeen.Before(oldestTime) {
oldestKey = key
oldestTime = entry.lastSeen
first = false
}
}
if !first {
delete(m.entries, oldestKey)
}
}
func clientKey(r *http.Request) string {
host := strings.TrimSpace(r.RemoteAddr)
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
host = parsedHost
}
if host == "" {
host = "unknown"
}
return host
}

View File

@@ -0,0 +1,177 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestIPRateLimiterRejectsExcessRequestsAcrossAuthPaths(t *testing.T) {
limiter := NewIPRateLimiter(2, time.Minute, 5*time.Minute)
h := limiter.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
req.RemoteAddr = "203.0.113.10:12345"
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
if rr.Code != http.StatusNoContent {
t.Fatalf("request %d expected 204, got %d", i+1, rr.Code)
}
}
third := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
third.RemoteAddr = "203.0.113.10:12345"
rr := httptest.NewRecorder()
h.ServeHTTP(rr, third)
if rr.Code != http.StatusTooManyRequests {
t.Fatalf("expected 429 for third request, got %d", rr.Code)
}
if rr.Header().Get("Retry-After") == "" {
t.Fatal("expected Retry-After header")
}
otherPath := httptest.NewRequest(http.MethodPost, "/auth/register", nil)
otherPath.RemoteAddr = "203.0.113.10:12345"
rr2 := httptest.NewRecorder()
h.ServeHTTP(rr2, otherPath)
if rr2.Code != http.StatusTooManyRequests {
t.Fatalf("expected shared limiter bucket per IP across paths, got %d", rr2.Code)
}
}
func TestIPRateLimiterCleanupRunsPeriodicallyWithoutEntryThreshold(t *testing.T) {
limiter := NewIPRateLimiter(5, 20*time.Millisecond, 20*time.Millisecond)
first := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
first.RemoteAddr = "203.0.113.10:12345"
if !limiter.allow(first) {
t.Fatal("expected first request to be allowed")
}
time.Sleep(50 * time.Millisecond)
second := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
second.RemoteAddr = "203.0.113.11:12345"
if !limiter.allow(second) {
t.Fatal("expected second request to be allowed")
}
limiter.mu.Lock()
defer limiter.mu.Unlock()
if len(limiter.entries) != 1 {
t.Fatalf("expected stale entries to be cleaned, got %d entries", len(limiter.entries))
}
if _, exists := limiter.entries["203.0.113.11"]; !exists {
t.Fatal("expected current client key to remain after cleanup")
}
}
func TestIPRateLimiterEvictsOldestEntryAtCapacity(t *testing.T) {
limiter := NewIPRateLimiter(5, time.Minute, time.Hour)
limiter.maxEntries = 2
req1 := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
req1.RemoteAddr = "203.0.113.10:12345"
if !limiter.allow(req1) {
t.Fatal("expected first request to be allowed")
}
time.Sleep(2 * time.Millisecond)
req2 := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
req2.RemoteAddr = "203.0.113.11:12345"
if !limiter.allow(req2) {
t.Fatal("expected second request to be allowed")
}
time.Sleep(2 * time.Millisecond)
req3 := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
req3.RemoteAddr = "203.0.113.12:12345"
if !limiter.allow(req3) {
t.Fatal("expected third request to be allowed")
}
limiter.mu.Lock()
defer limiter.mu.Unlock()
if len(limiter.entries) != 2 {
t.Fatalf("expected limiter size to remain capped at 2, got %d", len(limiter.entries))
}
if _, exists := limiter.entries["203.0.113.10"]; exists {
t.Fatal("expected oldest entry to be evicted")
}
if _, exists := limiter.entries["203.0.113.11"]; !exists {
t.Fatal("expected second entry to remain")
}
if _, exists := limiter.entries["203.0.113.12"]; !exists {
t.Fatal("expected newest entry to remain")
}
}
func TestIPRateLimiterWindowResetAllowsRequestAgain(t *testing.T) {
limiter := NewIPRateLimiter(1, 20*time.Millisecond, time.Minute)
req := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
req.RemoteAddr = "203.0.113.10:12345"
if !limiter.allow(req) {
t.Fatal("expected first request to be allowed")
}
if limiter.allow(req) {
t.Fatal("expected second request in same window to be blocked")
}
time.Sleep(25 * time.Millisecond)
if !limiter.allow(req) {
t.Fatal("expected request to be allowed after window reset")
}
}
func TestClientKeyVariants(t *testing.T) {
t.Run("host port", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "203.0.113.10:4321"
if got := clientKey(req); got != "203.0.113.10" {
t.Fatalf("expected parsed host, got %q", got)
}
})
t.Run("raw host", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "203.0.113.11"
if got := clientKey(req); got != "203.0.113.11" {
t.Fatalf("expected raw host, got %q", got)
}
})
t.Run("empty uses unknown", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = " "
if got := clientKey(req); got != "unknown" {
t.Fatalf("expected unknown fallback, got %q", got)
}
})
}
func TestNewIPRateLimiterAppliesDefaults(t *testing.T) {
limiter := NewIPRateLimiter(0, 0, 0)
if limiter.limit != 60 {
t.Fatalf("expected default limit 60, got %d", limiter.limit)
}
if limiter.window != time.Minute {
t.Fatalf("expected default window %v, got %v", time.Minute, limiter.window)
}
if limiter.ttl != 10*time.Minute {
t.Fatalf("expected default ttl %v, got %v", 10*time.Minute, limiter.ttl)
}
if limiter.maxEntries != defaultRateLimiterMaxEntries {
t.Fatalf("expected maxEntries %d, got %d", defaultRateLimiterMaxEntries, limiter.maxEntries)
}
if limiter.entries == nil {
t.Fatal("expected entries map to be initialized")
}
}

150
migrate/migrate.go Normal file
View File

@@ -0,0 +1,150 @@
package migrate
import (
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
const (
defaultMigrationsPath = "file://migrations"
)
type MigrationConfig struct {
Path string
LockTimeout time.Duration
StartupTimeout time.Duration
}
type MigrationStatus struct {
Version uint
Dirty bool
}
func RunMigrationsUp(databaseURL string, cfg MigrationConfig) error {
exec := func() error {
m, err := newMigrator(databaseURL, cfg.Path)
if err != nil {
return err
}
defer closeMigrator(m)
if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
return fmt.Errorf("run migrations up: %w", err)
}
return nil
}
return runWithTimeout(cfg.StartupTimeout, exec)
}
func RunMigrationsDown(databaseURL string, cfg MigrationConfig, steps int) error {
m, err := newMigrator(databaseURL, cfg.Path)
if err != nil {
return err
}
defer closeMigrator(m)
if steps > 0 {
if err := m.Steps(-steps); err != nil && !errors.Is(err, migrate.ErrNoChange) {
return fmt.Errorf("run migrations down %d steps: %w", steps, err)
}
return nil
}
if err := m.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
return fmt.Errorf("run migrations down all: %w", err)
}
return nil
}
func ForceMigrationVersion(databaseURL string, cfg MigrationConfig, version int) error {
m, err := newMigrator(databaseURL, cfg.Path)
if err != nil {
return err
}
defer closeMigrator(m)
if err := m.Force(version); err != nil {
return fmt.Errorf("force migration version %d: %w", version, err)
}
return nil
}
func GetMigrationStatus(databaseURL string, cfg MigrationConfig) (MigrationStatus, error) {
m, err := newMigrator(databaseURL, cfg.Path)
if err != nil {
return MigrationStatus{}, err
}
defer closeMigrator(m)
version, dirty, err := m.Version()
if errors.Is(err, migrate.ErrNilVersion) {
return MigrationStatus{}, nil
}
if err != nil {
return MigrationStatus{}, fmt.Errorf("read migration version: %w", err)
}
return MigrationStatus{Version: version, Dirty: dirty}, nil
}
func ResolveMigrationsPath(pathValue string) string {
if pathValue == "" {
pathValue = defaultMigrationsPath
}
if filepath.IsAbs(pathValue) {
return "file://" + pathValue
}
if len(pathValue) >= len("file://") && pathValue[:len("file://")] == "file://" {
return pathValue
}
wd, err := os.Getwd()
if err != nil {
return defaultMigrationsPath
}
return "file://" + filepath.Join(wd, pathValue)
}
func newMigrator(databaseURL string, path string) (*migrate.Migrate, error) {
migrationsPath := ResolveMigrationsPath(path)
m, err := migrate.New(migrationsPath, databaseURL)
if err != nil {
return nil, fmt.Errorf("open migration driver (%s): %w", migrationsPath, err)
}
return m, nil
}
func closeMigrator(m *migrate.Migrate) {
if m == nil {
return
}
sourceErr, dbErr := m.Close()
if sourceErr != nil || dbErr != nil {
_ = sourceErr
_ = dbErr
}
}
func runWithTimeout(timeout time.Duration, fn func() error) error {
if timeout <= 0 {
return fn()
}
done := make(chan error, 1)
go func() {
done <- fn()
}()
select {
case err := <-done:
return err
case <-time.After(timeout):
return fmt.Errorf("migration timed out after %s", timeout)
}
}

View File

@@ -0,0 +1,67 @@
package migrate
import (
"os"
"path/filepath"
"runtime"
"testing"
)
func TestRunMigrationsUpTwice(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
databaseURL := os.Getenv("INTEGRATION_DB_URL")
if databaseURL == "" {
t.Skip("set INTEGRATION_DB_URL to run migration integration tests")
}
cfg := MigrationConfig{Path: resolveMigrationsPath()}
if err := RunMigrationsUp(databaseURL, cfg); err != nil {
t.Fatalf("first migrate up: %v", err)
}
if err := RunMigrationsUp(databaseURL, cfg); err != nil {
t.Fatalf("second migrate up: %v", err)
}
status, err := GetMigrationStatus(databaseURL, cfg)
if err != nil {
t.Fatalf("get migration status: %v", err)
}
if status.Dirty {
t.Fatalf("expected non-dirty migration status")
}
if status.Version == 0 {
t.Fatalf("expected migration version > 0 after up")
}
}
func resolveMigrationsPath() string {
configured := os.Getenv("DB_MIGRATIONS_PATH")
if configured == "" {
configured = "migrations"
}
if filepath.IsAbs(configured) {
return configured
}
if len(configured) >= len("file://") && configured[:len("file://")] == "file://" {
return configured
}
if _, err := os.Stat(configured); err == nil {
return configured
}
_, thisFile, _, ok := runtime.Caller(0)
if !ok {
return configured
}
moduleRoot := filepath.Clean(filepath.Join(filepath.Dir(thisFile), "..", ".."))
candidate := filepath.Join(moduleRoot, configured)
if _, err := os.Stat(candidate); err == nil {
return candidate
}
return configured
}

View File

@@ -0,0 +1,77 @@
package migrate
import (
"errors"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
func TestResolveMigrationsPath(t *testing.T) {
if got := ResolveMigrationsPath(""); got != "file://migrations" && !strings.HasSuffix(got, "/migrations") {
t.Fatalf("expected default migrations path, got %q", got)
}
if got := ResolveMigrationsPath("file:///tmp/migs"); got != "file:///tmp/migs" {
t.Fatalf("expected passthrough file URI, got %q", got)
}
abs := filepath.Join(string(os.PathSeparator), "tmp", "migs")
if got := ResolveMigrationsPath(abs); got != "file://"+abs {
t.Fatalf("expected absolute path conversion, got %q", got)
}
rel := "migrations"
got := ResolveMigrationsPath(rel)
if !strings.HasPrefix(got, "file://") || !strings.HasSuffix(got, rel) {
t.Fatalf("expected file URI ending with %q, got %q", rel, got)
}
}
func TestRunWithTimeout(t *testing.T) {
if err := runWithTimeout(0, func() error { return nil }); err != nil {
t.Fatalf("expected no error for immediate execution, got %v", err)
}
if err := runWithTimeout(50*time.Millisecond, func() error {
time.Sleep(5 * time.Millisecond)
return nil
}); err != nil {
t.Fatalf("expected fn to finish before timeout, got %v", err)
}
err := runWithTimeout(10*time.Millisecond, func() error {
time.Sleep(50 * time.Millisecond)
return nil
})
if err == nil || !strings.Contains(err.Error(), "timed out") {
t.Fatalf("expected timeout error, got %v", err)
}
expected := errors.New("boom")
err = runWithTimeout(20*time.Millisecond, func() error { return expected })
if !errors.Is(err, expected) {
t.Fatalf("expected function error propagation, got %v", err)
}
}
func TestMigrationAPIs_InvalidSourcePath(t *testing.T) {
cfg := MigrationConfig{Path: "definitely-missing-migrations-dir"}
if err := RunMigrationsUp("postgres://ignored", cfg); err == nil {
t.Fatalf("expected RunMigrationsUp error for missing source path")
}
if err := RunMigrationsDown("postgres://ignored", cfg, 1); err == nil {
t.Fatalf("expected RunMigrationsDown error for missing source path")
}
if err := ForceMigrationVersion("postgres://ignored", cfg, 1); err == nil {
t.Fatalf("expected ForceMigrationVersion error for missing source path")
}
if _, err := GetMigrationStatus("postgres://ignored", cfg); err == nil {
t.Fatalf("expected GetMigrationStatus error for missing source path")
}
closeMigrator(nil)
}

345
smtp/smtp_mailer.go Normal file
View File

@@ -0,0 +1,345 @@
package smtp
import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/smtp"
"strings"
"time"
)
type SMTPMode string
const (
SMTPModeSSL SMTPMode = "ssl"
SMTPModeTLS SMTPMode = "tls"
SMTPModeUnencrypted SMTPMode = "unencrypted"
)
type SMTPConfig struct {
Host string
Port int
Username string
Password string
From string
Mode SMTPMode
Auth string
}
type Mailer interface {
Send(ctx context.Context, to string, subject string, htmlBody string, textBody string) error
}
type SMTPMailer struct {
cfg SMTPConfig
}
const defaultSMTPOperationTimeout = 30 * time.Second
func NewSMTPMailer(cfg SMTPConfig) *SMTPMailer {
return &SMTPMailer{cfg: cfg}
}
func (m *SMTPMailer) Send(ctx context.Context, to string, subject string, htmlBody string, textBody string) error {
if strings.TrimSpace(m.cfg.Host) == "" || m.cfg.Port == 0 || strings.TrimSpace(m.cfg.From) == "" {
return fmt.Errorf("smtp is not configured")
}
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return err
}
now := time.Now()
addr := fmt.Sprintf("%s:%d", m.cfg.Host, m.cfg.Port)
deadline := smtpSendDeadline(ctx, now)
if !deadline.After(now) {
if err := ctx.Err(); err != nil {
return err
}
return context.DeadlineExceeded
}
var client *smtp.Client
var err error
switch m.cfg.Mode {
case SMTPModeSSL:
client, err = dialSSL(ctx, addr, m.cfg.Host, deadline)
case SMTPModeUnencrypted:
client, err = dialPlain(ctx, addr, m.cfg.Host, deadline)
case SMTPModeTLS:
fallthrough
default:
client, err = dialStartTLS(ctx, addr, m.cfg.Host, deadline)
}
if err != nil {
return err
}
defer client.Close()
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
_ = client.Close()
case <-done:
}
}()
defer close(done)
if m.cfg.Username != "" {
if err := authenticate(client, m.cfg); err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
return err
}
}
if err := client.Mail(m.cfg.From); err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
return fmt.Errorf("smtp mail from: %w", err)
}
if err := client.Rcpt(to); err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
return fmt.Errorf("smtp rcpt to: %w", err)
}
writer, err := client.Data()
if err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
return fmt.Errorf("smtp data: %w", err)
}
message := buildMIMEMessage(m.cfg.From, to, subject, textBody, htmlBody)
if _, err := writer.Write([]byte(message)); err != nil {
_ = writer.Close()
if ctx.Err() != nil {
return ctx.Err()
}
return fmt.Errorf("smtp write body: %w", err)
}
if err := writer.Close(); err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
return fmt.Errorf("smtp close body: %w", err)
}
if err := client.Quit(); err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
return fmt.Errorf("smtp quit: %w", err)
}
return nil
}
func smtpSendDeadline(ctx context.Context, now time.Time) time.Time {
if ctx != nil {
if deadline, ok := ctx.Deadline(); ok {
return deadline
}
}
return now.Add(defaultSMTPOperationTimeout)
}
func authenticate(client *smtp.Client, cfg SMTPConfig) error {
authMethod := strings.ToLower(strings.TrimSpace(cfg.Auth))
if authMethod == "" {
authMethod = "auto"
}
_, methodsRaw := client.Extension("AUTH")
methods := strings.ToUpper(methodsRaw)
tryLogin := strings.Contains(methods, "LOGIN")
tryPlain := strings.Contains(methods, "PLAIN")
switch authMethod {
case "none":
return nil
case "login":
return authWithLogin(client, cfg)
case "plain":
return authWithPlain(client, cfg)
case "auto":
if tryLogin {
if err := authWithLogin(client, cfg); err == nil {
return nil
}
}
if tryPlain {
if err := authWithPlain(client, cfg); err == nil {
return nil
}
}
if !tryLogin && !tryPlain {
// Last fallback if server did not advertise methods.
if err := authWithLogin(client, cfg); err == nil {
return nil
}
if err := authWithPlain(client, cfg); err == nil {
return nil
}
}
return fmt.Errorf("smtp auth failed for available methods")
default:
return fmt.Errorf("unsupported smtp auth method: %s", authMethod)
}
}
func authWithPlain(client *smtp.Client, cfg SMTPConfig) error {
auth := smtp.PlainAuth("", cfg.Username, cfg.Password, cfg.Host)
if err := client.Auth(auth); err != nil {
return fmt.Errorf("smtp plain auth: %w", err)
}
return nil
}
func authWithLogin(client *smtp.Client, cfg SMTPConfig) error {
if err := client.Auth(loginAuth{username: cfg.Username, password: cfg.Password}); err != nil {
return fmt.Errorf("smtp login auth: %w", err)
}
return nil
}
type loginAuth struct {
username string
password string
}
func (a loginAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
return "LOGIN", []byte{}, nil
}
func (a loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if !more {
return nil, nil
}
challengeRaw := strings.TrimSpace(string(fromServer))
challenge := strings.ToLower(challengeRaw)
if decoded, err := base64.StdEncoding.DecodeString(challengeRaw); err == nil {
challenge = strings.ToLower(string(decoded))
}
switch {
case strings.Contains(challenge, "username"):
return []byte(a.username), nil
case strings.Contains(challenge, "password"):
return []byte(a.password), nil
}
return nil, fmt.Errorf("unexpected smtp login challenge")
}
func dialSSL(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) {
timeout := time.Until(deadline)
if timeout <= 0 {
if err := ctx.Err(); err != nil {
return nil, err
}
return nil, context.DeadlineExceeded
}
dialer := &tls.Dialer{
NetDialer: &net.Dialer{Timeout: timeout},
Config: &tls.Config{ServerName: host},
}
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
if ctx.Err() != nil {
return nil, ctx.Err()
}
return nil, fmt.Errorf("smtp ssl dial: %w", err)
}
_ = conn.SetDeadline(deadline)
client, err := smtp.NewClient(conn, host)
if err != nil {
return nil, fmt.Errorf("smtp new client ssl: %w", err)
}
return client, nil
}
func dialPlain(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) {
timeout := time.Until(deadline)
if timeout <= 0 {
if err := ctx.Err(); err != nil {
return nil, err
}
return nil, context.DeadlineExceeded
}
dialer := &net.Dialer{Timeout: timeout}
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
if ctx.Err() != nil {
return nil, ctx.Err()
}
return nil, fmt.Errorf("smtp plain dial: %w", err)
}
_ = conn.SetDeadline(deadline)
client, err := smtp.NewClient(conn, host)
if err != nil {
return nil, fmt.Errorf("smtp new client plain: %w", err)
}
return client, nil
}
func dialStartTLS(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) {
client, err := dialPlain(ctx, addr, host, deadline)
if err != nil {
return nil, err
}
if ok, _ := client.Extension("STARTTLS"); !ok {
_ = client.Close()
return nil, fmt.Errorf("smtp server does not support STARTTLS")
}
if err := client.StartTLS(&tls.Config{ServerName: host}); err != nil {
_ = client.Close()
if ctx.Err() != nil {
return nil, ctx.Err()
}
return nil, fmt.Errorf("smtp starttls: %w", err)
}
return client, nil
}
func buildMIMEMessage(from, to, subject, textBody, htmlBody string) string {
boundary := randomMIMEBoundary()
return strings.Join([]string{
fmt.Sprintf("From: %s", from),
fmt.Sprintf("To: %s", to),
fmt.Sprintf("Subject: %s", subject),
"MIME-Version: 1.0",
fmt.Sprintf("Content-Type: multipart/alternative; boundary=%s", boundary),
"",
fmt.Sprintf("--%s", boundary),
"Content-Type: text/plain; charset=UTF-8",
"",
textBody,
fmt.Sprintf("--%s", boundary),
"Content-Type: text/html; charset=UTF-8",
"",
htmlBody,
fmt.Sprintf("--%s--", boundary),
"",
}, "\r\n")
}
func randomMIMEBoundary() string {
raw := make([]byte, 12)
if _, err := rand.Read(raw); err != nil {
return fmt.Sprintf("mime-boundary-%d", time.Now().UnixNano())
}
return "mime-boundary-" + hex.EncodeToString(raw)
}

View File

@@ -0,0 +1,126 @@
package smtp
import (
"context"
"strings"
"testing"
"time"
)
func TestSendReturnsConfigurationErrorWhenSMTPMissing(t *testing.T) {
mailer := NewSMTPMailer(SMTPConfig{Host: "", Port: 587, From: "noreply@example.com", Mode: SMTPModeTLS})
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>body</p>", "body")
if err == nil {
t.Fatal("expected configuration error")
}
if !strings.Contains(err.Error(), "smtp is not configured") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestSendReturnsDeadlineExceededForExpiredContext(t *testing.T) {
mailer := NewSMTPMailer(SMTPConfig{Host: "smtp.example.com", Port: 587, From: "noreply@example.com", Mode: SMTPModeTLS})
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
defer cancel()
err := mailer.Send(ctx, "to@example.com", "subject", "<p>body</p>", "body")
if err == nil {
t.Fatal("expected context deadline exceeded")
}
if err != context.DeadlineExceeded {
t.Fatalf("expected context deadline exceeded, got %v", err)
}
}
func TestLoginAuthNextHandlesChallenges(t *testing.T) {
auth := loginAuth{username: "alice", password: "s3cret"}
proto, initial, err := auth.Start(nil)
if err != nil {
t.Fatalf("unexpected start error: %v", err)
}
if proto != "LOGIN" {
t.Fatalf("expected LOGIN auth proto, got %q", proto)
}
if len(initial) != 0 {
t.Fatalf("expected empty initial response, got %q", string(initial))
}
value, err := auth.Next([]byte("Username:"), true)
if err != nil {
t.Fatalf("unexpected username challenge error: %v", err)
}
if string(value) != "alice" {
t.Fatalf("expected username response alice, got %q", string(value))
}
value, err = auth.Next([]byte("UGFzc3dvcmQ6"), true)
if err != nil {
t.Fatalf("unexpected password challenge error: %v", err)
}
if string(value) != "s3cret" {
t.Fatalf("expected password response s3cret, got %q", string(value))
}
}
func TestLoginAuthNextHandlesTerminalAndUnexpectedChallenge(t *testing.T) {
auth := loginAuth{username: "alice", password: "s3cret"}
value, err := auth.Next([]byte("ignored"), false)
if err != nil {
t.Fatalf("expected nil error when more=false, got %v", err)
}
if value != nil {
t.Fatalf("expected nil value when more=false, got %q", string(value))
}
if _, err := auth.Next([]byte("realm"), true); err == nil {
t.Fatal("expected error for unexpected login challenge")
}
}
func TestBuildMIMEMessageContainsMultipartSections(t *testing.T) {
msg := buildMIMEMessage("from@example.com", "to@example.com", "Subject", "plain body", "<p>html body</p>")
checks := []string{
"From: from@example.com",
"To: to@example.com",
"Subject: Subject",
"MIME-Version: 1.0",
"Content-Type: text/plain; charset=UTF-8",
"plain body",
"Content-Type: text/html; charset=UTF-8",
"<p>html body</p>",
}
for _, snippet := range checks {
if !strings.Contains(msg, snippet) {
t.Fatalf("expected MIME message to contain %q", snippet)
}
}
if !strings.Contains(msg, "\r\n") {
t.Fatal("expected CRLF separators in MIME message")
}
if !strings.Contains(msg, "Content-Type: multipart/alternative; boundary=mime-boundary-") {
t.Fatalf("expected random mime boundary header, got %q", msg)
}
if !strings.Contains(msg, "--mime-boundary-") {
t.Fatalf("expected mime boundary delimiters in message, got %q", msg)
}
}
func TestDialHelpersReturnDeadlineExceededWhenDeadlineHasPassed(t *testing.T) {
deadline := time.Now().Add(-time.Second)
ctx := context.Background()
if _, err := dialPlain(ctx, "smtp.example.com:25", "smtp.example.com", deadline); err != context.DeadlineExceeded {
t.Fatalf("dialPlain expected context deadline exceeded, got %v", err)
}
if _, err := dialSSL(ctx, "smtp.example.com:465", "smtp.example.com", deadline); err != context.DeadlineExceeded {
t.Fatalf("dialSSL expected context deadline exceeded, got %v", err)
}
if _, err := dialStartTLS(ctx, "smtp.example.com:587", "smtp.example.com", deadline); err != context.DeadlineExceeded {
t.Fatalf("dialStartTLS expected context deadline exceeded, got %v", err)
}
}

View File

@@ -0,0 +1,424 @@
package smtp
import (
"bufio"
"context"
"net"
"strconv"
"strings"
"sync"
"testing"
"time"
)
type smtpTestServerConfig struct {
extensions []string
failAuth bool
failMail bool
failRcpt bool
failData bool
failQuit bool
}
type smtpTestServerState struct {
authCalls int
mailFrom string
rcptTo string
data string
}
type smtpTestServer struct {
host string
port int
state *smtpTestServerState
stop func()
}
func requireSMTPIntegration(t *testing.T) {
t.Helper()
if testing.Short() {
t.Skip("skipping smtp integration test in short mode")
}
}
func startSMTPTestServer(t *testing.T, cfg smtpTestServerConfig) *smtpTestServer {
t.Helper()
requireSMTPIntegration(t)
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen smtp test server: %v", err)
}
state := &smtpTestServerState{}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
r := bufio.NewReader(conn)
w := bufio.NewWriter(conn)
write := func(line string) bool {
if _, err := w.WriteString(line + "\r\n"); err != nil {
return false
}
return w.Flush() == nil
}
if !write("220 localhost ESMTP ready") {
return
}
inData := false
var dataBuf strings.Builder
for {
line, err := r.ReadString('\n')
if err != nil {
return
}
line = strings.TrimRight(line, "\r\n")
if inData {
if line == "." {
state.data = dataBuf.String()
if !write("250 2.0.0 queued") {
return
}
inData = false
continue
}
dataBuf.WriteString(line)
dataBuf.WriteString("\n")
continue
}
parts := strings.SplitN(line, " ", 2)
cmd := strings.ToUpper(parts[0])
arg := ""
if len(parts) > 1 {
arg = parts[1]
}
switch cmd {
case "EHLO", "HELO":
lines := append([]string{"localhost"}, cfg.extensions...)
for i, ext := range lines {
prefix := "250-"
if i == len(lines)-1 {
prefix = "250 "
}
if !write(prefix + ext) {
return
}
}
case "AUTH":
state.authCalls++
if cfg.failAuth {
if !write("535 5.7.8 auth failed") {
return
}
continue
}
authArg := strings.ToUpper(strings.TrimSpace(arg))
switch {
case strings.HasPrefix(authArg, "PLAIN"):
if !write("235 2.7.0 authenticated") {
return
}
case strings.HasPrefix(authArg, "LOGIN"):
if !write("334 VXNlcm5hbWU6") {
return
}
if _, err := r.ReadString('\n'); err != nil {
return
}
if !write("334 UGFzc3dvcmQ6") {
return
}
if _, err := r.ReadString('\n'); err != nil {
return
}
if !write("235 2.7.0 authenticated") {
return
}
default:
if !write("504 5.5.4 unsupported auth mechanism") {
return
}
}
case "MAIL":
state.mailFrom = arg
if cfg.failMail {
if !write("550 5.1.0 sender rejected") {
return
}
continue
}
if !write("250 2.1.0 ok") {
return
}
case "RCPT":
state.rcptTo = arg
if cfg.failRcpt {
if !write("550 5.1.1 recipient rejected") {
return
}
continue
}
if !write("250 2.1.5 ok") {
return
}
case "DATA":
if cfg.failData {
if !write("554 5.5.0 data rejected") {
return
}
continue
}
if !write("354 end data with <CR><LF>.<CR><LF>") {
return
}
inData = true
dataBuf.Reset()
case "QUIT":
if cfg.failQuit {
_ = write("554 5.5.1 quit failed")
return
}
_ = write("221 2.0.0 bye")
return
default:
if !write("250 2.0.0 ok") {
return
}
}
}
}()
tcpAddr := ln.Addr().(*net.TCPAddr)
return &smtpTestServer{
host: "127.0.0.1",
port: tcpAddr.Port,
state: state,
stop: func() {
_ = ln.Close()
wg.Wait()
},
}
}
func TestSendUnencryptedSuccess(t *testing.T) {
srv := startSMTPTestServer(t, smtpTestServerConfig{})
defer srv.stop()
mailer := NewSMTPMailer(SMTPConfig{
Host: srv.host,
Port: srv.port,
From: "noreply@example.com",
Mode: SMTPModeUnencrypted,
})
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
if err != nil {
t.Fatalf("Send: %v", err)
}
if !strings.Contains(srv.state.mailFrom, "noreply@example.com") {
t.Fatalf("expected mail from to include sender, got %q", srv.state.mailFrom)
}
if !strings.Contains(srv.state.rcptTo, "to@example.com") {
t.Fatalf("expected rcpt to include recipient, got %q", srv.state.rcptTo)
}
if !strings.Contains(srv.state.data, "Subject: subject") {
t.Fatalf("expected message data to include subject, got %q", srv.state.data)
}
}
func TestSendUnencryptedErrorPaths(t *testing.T) {
tests := []struct {
name string
cfg smtpTestServerConfig
wantErr string
}{
{name: "mail from error", cfg: smtpTestServerConfig{failMail: true}, wantErr: "smtp mail from"},
{name: "rcpt error", cfg: smtpTestServerConfig{failRcpt: true}, wantErr: "smtp rcpt to"},
{name: "data error", cfg: smtpTestServerConfig{failData: true}, wantErr: "smtp data"},
{name: "quit error", cfg: smtpTestServerConfig{failQuit: true}, wantErr: "smtp quit"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
srv := startSMTPTestServer(t, tc.cfg)
defer srv.stop()
mailer := NewSMTPMailer(SMTPConfig{
Host: srv.host,
Port: srv.port,
From: "noreply@example.com",
Mode: SMTPModeUnencrypted,
})
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
if err == nil || !strings.Contains(err.Error(), tc.wantErr) {
t.Fatalf("expected %q error, got %v", tc.wantErr, err)
}
})
}
}
func TestSendAuthModes(t *testing.T) {
t.Run("auth none does not issue AUTH command", func(t *testing.T) {
srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH PLAIN LOGIN"}})
defer srv.stop()
mailer := NewSMTPMailer(SMTPConfig{
Host: srv.host,
Port: srv.port,
From: "noreply@example.com",
Mode: SMTPModeUnencrypted,
Username: "alice",
Password: "secret",
Auth: "none",
})
if err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello"); err != nil {
t.Fatalf("Send: %v", err)
}
if srv.state.authCalls != 0 {
t.Fatalf("expected no AUTH command, got %d", srv.state.authCalls)
}
})
t.Run("auth plain", func(t *testing.T) {
srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH PLAIN"}})
defer srv.stop()
mailer := NewSMTPMailer(SMTPConfig{
Host: srv.host,
Port: srv.port,
From: "noreply@example.com",
Mode: SMTPModeUnencrypted,
Username: "alice",
Password: "secret",
Auth: "plain",
})
if err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello"); err != nil {
t.Fatalf("Send: %v", err)
}
if srv.state.authCalls == 0 {
t.Fatal("expected AUTH command for plain auth")
}
})
t.Run("auth login", func(t *testing.T) {
srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH LOGIN"}})
defer srv.stop()
mailer := NewSMTPMailer(SMTPConfig{
Host: srv.host,
Port: srv.port,
From: "noreply@example.com",
Mode: SMTPModeUnencrypted,
Username: "alice",
Password: "secret",
Auth: "login",
})
if err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello"); err != nil {
t.Fatalf("Send: %v", err)
}
if srv.state.authCalls == 0 {
t.Fatal("expected AUTH command for login auth")
}
})
t.Run("auth auto failure", func(t *testing.T) {
srv := startSMTPTestServer(t, smtpTestServerConfig{failAuth: true})
defer srv.stop()
mailer := NewSMTPMailer(SMTPConfig{
Host: srv.host,
Port: srv.port,
From: "noreply@example.com",
Mode: SMTPModeUnencrypted,
Username: "alice",
Password: "secret",
Auth: "auto",
})
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
if err == nil || !strings.Contains(err.Error(), "smtp auth failed") {
t.Fatalf("expected smtp auth failed error, got %v", err)
}
})
t.Run("unsupported auth method", func(t *testing.T) {
srv := startSMTPTestServer(t, smtpTestServerConfig{})
defer srv.stop()
mailer := NewSMTPMailer(SMTPConfig{
Host: srv.host,
Port: srv.port,
From: "noreply@example.com",
Mode: SMTPModeUnencrypted,
Username: "alice",
Password: "secret",
Auth: "weird",
})
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
if err == nil || !strings.Contains(err.Error(), "unsupported smtp auth method") {
t.Fatalf("expected unsupported auth method error, got %v", err)
}
})
}
func TestSendTLSFailsWhenStartTLSExtensionMissing(t *testing.T) {
srv := startSMTPTestServer(t, smtpTestServerConfig{})
defer srv.stop()
mailer := NewSMTPMailer(SMTPConfig{
Host: srv.host,
Port: srv.port,
From: "noreply@example.com",
Mode: SMTPModeTLS,
})
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
if err == nil || !strings.Contains(err.Error(), "does not support STARTTLS") {
t.Fatalf("expected STARTTLS unsupported error, got %v", err)
}
}
func TestDialHelpersWrapDialErrors(t *testing.T) {
requireSMTPIntegration(t)
ctx := context.Background()
deadline := time.Now().Add(500 * time.Millisecond)
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen ephemeral: %v", err)
}
addr := ln.Addr().String()
_ = ln.Close()
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
t.Fatalf("split host port: %v", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
t.Fatalf("parse port: %v", err)
}
unusedAddr := net.JoinHostPort(host, strconv.Itoa(port))
if _, err := dialPlain(ctx, unusedAddr, host, deadline); err == nil || !strings.Contains(err.Error(), "smtp plain dial") {
t.Fatalf("expected wrapped plain dial error, got %v", err)
}
if _, err := dialSSL(ctx, unusedAddr, host, deadline); err == nil || !strings.Contains(err.Error(), "smtp ssl dial") {
t.Fatalf("expected wrapped ssl dial error, got %v", err)
}
}

47
smtp/smtp_mailer_test.go Normal file
View File

@@ -0,0 +1,47 @@
package smtp
import (
"context"
"errors"
"testing"
"time"
)
func TestSendHonorsCanceledContext(t *testing.T) {
mailer := NewSMTPMailer(SMTPConfig{
Host: "smtp.example.com",
Port: 587,
From: "noreply@example.com",
Mode: SMTPModeTLS,
})
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := mailer.Send(ctx, "to@example.com", "subject", "<p>hello</p>", "hello")
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected %v, got %v", context.Canceled, err)
}
}
func TestSMTPSendDeadlineUsesContextDeadline(t *testing.T) {
now := time.Date(2026, time.January, 1, 12, 0, 0, 0, time.UTC)
want := now.Add(2 * time.Minute)
ctx, cancel := context.WithDeadline(context.Background(), want)
defer cancel()
got := smtpSendDeadline(ctx, now)
if !got.Equal(want) {
t.Fatalf("expected context deadline %v, got %v", want, got)
}
}
func TestSMTPSendDeadlineUsesDefaultWhenContextHasNoDeadline(t *testing.T) {
now := time.Date(2026, time.January, 1, 12, 0, 0, 0, time.UTC)
want := now.Add(defaultSMTPOperationTimeout)
got := smtpSendDeadline(context.Background(), now)
if !got.Equal(want) {
t.Fatalf("expected default deadline %v, got %v", want, got)
}
}

70
worker/poller.go Normal file
View File

@@ -0,0 +1,70 @@
package worker
import (
"context"
"log/slog"
"time"
)
const defaultTaskName = "batch worker"
// BatchRunner executes one batch and returns the number of processed records.
// Returning 0 stops the current drain cycle.
type BatchRunner func(ctx context.Context, limit int) (int, error)
func Run(ctx context.Context, runner BatchRunner, interval time.Duration, batchSize int, logger *slog.Logger, taskName string) {
if runner == nil || interval <= 0 || batchSize <= 0 {
return
}
if logger == nil {
logger = slog.Default()
}
if taskName == "" {
taskName = defaultTaskName
}
RunOnce(ctx, runner, batchSize, logger, taskName)
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
RunOnce(ctx, runner, batchSize, logger, taskName)
}
}
}
func RunOnce(ctx context.Context, runner BatchRunner, batchSize int, logger *slog.Logger, taskName string) {
if runner == nil || batchSize <= 0 {
return
}
if logger == nil {
logger = slog.Default()
}
if taskName == "" {
taskName = defaultTaskName
}
totalProcessed := 0
for {
processed, err := runner(ctx, batchSize)
if err != nil {
if ctx.Err() == nil {
logger.Warn(taskName+" run failed", "error", err)
}
return
}
if processed <= 0 {
break
}
totalProcessed += processed
}
if totalProcessed > 0 {
logger.Info(taskName+" completed", "processed", totalProcessed)
}
}

125
worker/poller_test.go Normal file
View File

@@ -0,0 +1,125 @@
package worker
import (
"context"
"errors"
"io"
"log/slog"
"sync"
"testing"
"time"
)
type fakeScheduledPostPromoter struct {
mu sync.Mutex
calls int
limits []int
results []int
errAt int
callHook func(int)
}
func (f *fakeScheduledPostPromoter) PromoteDueScheduled(_ context.Context, limit int) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()
f.calls++
callNo := f.calls
f.limits = append(f.limits, limit)
if f.callHook != nil {
f.callHook(callNo)
}
if f.errAt > 0 && callNo == f.errAt {
return 0, errors.New("boom")
}
if len(f.results) == 0 {
return 0, nil
}
result := f.results[0]
f.results = f.results[1:]
return result, nil
}
func (f *fakeScheduledPostPromoter) Calls() int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls
}
func (f *fakeScheduledPostPromoter) Limits() []int {
f.mu.Lock()
defer f.mu.Unlock()
out := make([]int, len(f.limits))
copy(out, f.limits)
return out
}
func newDiscardLogger() *slog.Logger {
return slog.New(slog.NewTextHandler(io.Discard, nil))
}
func TestPromoteDueScheduledOnceBatchesUntilEmpty(t *testing.T) {
repo := &fakeScheduledPostPromoter{results: []int{2, 3, 0}}
RunOnce(context.Background(), repo.PromoteDueScheduled, 50, newDiscardLogger(), "scheduled post promotion")
if got := repo.Calls(); got != 3 {
t.Fatalf("expected 3 calls, got %d", got)
}
limits := repo.Limits()
if len(limits) != 3 || limits[0] != 50 || limits[1] != 50 || limits[2] != 50 {
t.Fatalf("unexpected limits: %+v", limits)
}
}
func TestPromoteDueScheduledOnceStopsOnError(t *testing.T) {
repo := &fakeScheduledPostPromoter{
results: []int{4, 4},
errAt: 2,
}
RunOnce(context.Background(), repo.PromoteDueScheduled, 25, newDiscardLogger(), "scheduled post promotion")
if got := repo.Calls(); got != 2 {
t.Fatalf("expected 2 calls before stop, got %d", got)
}
}
func TestRunScheduledPostPromoterStopsOnContextCancel(t *testing.T) {
callCh := make(chan int, 8)
repo := &fakeScheduledPostPromoter{
results: []int{0, 0, 0, 0},
callHook: func(call int) {
callCh <- call
},
}
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
Run(ctx, repo.PromoteDueScheduled, 10*time.Millisecond, 10, newDiscardLogger(), "scheduled post promotion")
close(done)
}()
timeout := time.After(300 * time.Millisecond)
for {
select {
case <-callCh:
if repo.Calls() >= 2 {
cancel()
select {
case <-done:
return
case <-time.After(200 * time.Millisecond):
t.Fatal("promoter did not stop after context cancel")
}
}
case <-timeout:
cancel()
t.Fatal("timed out waiting for promoter calls")
}
}
}