add core lib
This commit is contained in:
150
migrate/migrate.go
Normal file
150
migrate/migrate.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
_ "github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMigrationsPath = "file://migrations"
|
||||
)
|
||||
|
||||
type MigrationConfig struct {
|
||||
Path string
|
||||
LockTimeout time.Duration
|
||||
StartupTimeout time.Duration
|
||||
}
|
||||
|
||||
type MigrationStatus struct {
|
||||
Version uint
|
||||
Dirty bool
|
||||
}
|
||||
|
||||
func RunMigrationsUp(databaseURL string, cfg MigrationConfig) error {
|
||||
exec := func() error {
|
||||
m, err := newMigrator(databaseURL, cfg.Path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer closeMigrator(m)
|
||||
|
||||
if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
return fmt.Errorf("run migrations up: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return runWithTimeout(cfg.StartupTimeout, exec)
|
||||
}
|
||||
|
||||
func RunMigrationsDown(databaseURL string, cfg MigrationConfig, steps int) error {
|
||||
m, err := newMigrator(databaseURL, cfg.Path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer closeMigrator(m)
|
||||
|
||||
if steps > 0 {
|
||||
if err := m.Steps(-steps); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
return fmt.Errorf("run migrations down %d steps: %w", steps, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
return fmt.Errorf("run migrations down all: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ForceMigrationVersion(databaseURL string, cfg MigrationConfig, version int) error {
|
||||
m, err := newMigrator(databaseURL, cfg.Path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer closeMigrator(m)
|
||||
|
||||
if err := m.Force(version); err != nil {
|
||||
return fmt.Errorf("force migration version %d: %w", version, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetMigrationStatus(databaseURL string, cfg MigrationConfig) (MigrationStatus, error) {
|
||||
m, err := newMigrator(databaseURL, cfg.Path)
|
||||
if err != nil {
|
||||
return MigrationStatus{}, err
|
||||
}
|
||||
defer closeMigrator(m)
|
||||
|
||||
version, dirty, err := m.Version()
|
||||
if errors.Is(err, migrate.ErrNilVersion) {
|
||||
return MigrationStatus{}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return MigrationStatus{}, fmt.Errorf("read migration version: %w", err)
|
||||
}
|
||||
|
||||
return MigrationStatus{Version: version, Dirty: dirty}, nil
|
||||
}
|
||||
|
||||
func ResolveMigrationsPath(pathValue string) string {
|
||||
if pathValue == "" {
|
||||
pathValue = defaultMigrationsPath
|
||||
}
|
||||
if filepath.IsAbs(pathValue) {
|
||||
return "file://" + pathValue
|
||||
}
|
||||
if len(pathValue) >= len("file://") && pathValue[:len("file://")] == "file://" {
|
||||
return pathValue
|
||||
}
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return defaultMigrationsPath
|
||||
}
|
||||
return "file://" + filepath.Join(wd, pathValue)
|
||||
}
|
||||
|
||||
func newMigrator(databaseURL string, path string) (*migrate.Migrate, error) {
|
||||
migrationsPath := ResolveMigrationsPath(path)
|
||||
m, err := migrate.New(migrationsPath, databaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open migration driver (%s): %w", migrationsPath, err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func closeMigrator(m *migrate.Migrate) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
sourceErr, dbErr := m.Close()
|
||||
if sourceErr != nil || dbErr != nil {
|
||||
_ = sourceErr
|
||||
_ = dbErr
|
||||
}
|
||||
}
|
||||
|
||||
func runWithTimeout(timeout time.Duration, fn func() error) error {
|
||||
if timeout <= 0 {
|
||||
return fn()
|
||||
}
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- fn()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-time.After(timeout):
|
||||
return fmt.Errorf("migration timed out after %s", timeout)
|
||||
}
|
||||
}
|
||||
67
migrate/migrate_integration_test.go
Normal file
67
migrate/migrate_integration_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRunMigrationsUpTwice(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
databaseURL := os.Getenv("INTEGRATION_DB_URL")
|
||||
if databaseURL == "" {
|
||||
t.Skip("set INTEGRATION_DB_URL to run migration integration tests")
|
||||
}
|
||||
|
||||
cfg := MigrationConfig{Path: resolveMigrationsPath()}
|
||||
if err := RunMigrationsUp(databaseURL, cfg); err != nil {
|
||||
t.Fatalf("first migrate up: %v", err)
|
||||
}
|
||||
if err := RunMigrationsUp(databaseURL, cfg); err != nil {
|
||||
t.Fatalf("second migrate up: %v", err)
|
||||
}
|
||||
|
||||
status, err := GetMigrationStatus(databaseURL, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("get migration status: %v", err)
|
||||
}
|
||||
if status.Dirty {
|
||||
t.Fatalf("expected non-dirty migration status")
|
||||
}
|
||||
if status.Version == 0 {
|
||||
t.Fatalf("expected migration version > 0 after up")
|
||||
}
|
||||
}
|
||||
|
||||
func resolveMigrationsPath() string {
|
||||
configured := os.Getenv("DB_MIGRATIONS_PATH")
|
||||
if configured == "" {
|
||||
configured = "migrations"
|
||||
}
|
||||
|
||||
if filepath.IsAbs(configured) {
|
||||
return configured
|
||||
}
|
||||
if len(configured) >= len("file://") && configured[:len("file://")] == "file://" {
|
||||
return configured
|
||||
}
|
||||
if _, err := os.Stat(configured); err == nil {
|
||||
return configured
|
||||
}
|
||||
|
||||
_, thisFile, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
return configured
|
||||
}
|
||||
moduleRoot := filepath.Clean(filepath.Join(filepath.Dir(thisFile), "..", ".."))
|
||||
candidate := filepath.Join(moduleRoot, configured)
|
||||
if _, err := os.Stat(candidate); err == nil {
|
||||
return candidate
|
||||
}
|
||||
|
||||
return configured
|
||||
}
|
||||
77
migrate/migrate_unit_test.go
Normal file
77
migrate/migrate_unit_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestResolveMigrationsPath(t *testing.T) {
|
||||
if got := ResolveMigrationsPath(""); got != "file://migrations" && !strings.HasSuffix(got, "/migrations") {
|
||||
t.Fatalf("expected default migrations path, got %q", got)
|
||||
}
|
||||
|
||||
if got := ResolveMigrationsPath("file:///tmp/migs"); got != "file:///tmp/migs" {
|
||||
t.Fatalf("expected passthrough file URI, got %q", got)
|
||||
}
|
||||
|
||||
abs := filepath.Join(string(os.PathSeparator), "tmp", "migs")
|
||||
if got := ResolveMigrationsPath(abs); got != "file://"+abs {
|
||||
t.Fatalf("expected absolute path conversion, got %q", got)
|
||||
}
|
||||
|
||||
rel := "migrations"
|
||||
got := ResolveMigrationsPath(rel)
|
||||
if !strings.HasPrefix(got, "file://") || !strings.HasSuffix(got, rel) {
|
||||
t.Fatalf("expected file URI ending with %q, got %q", rel, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunWithTimeout(t *testing.T) {
|
||||
if err := runWithTimeout(0, func() error { return nil }); err != nil {
|
||||
t.Fatalf("expected no error for immediate execution, got %v", err)
|
||||
}
|
||||
|
||||
if err := runWithTimeout(50*time.Millisecond, func() error {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("expected fn to finish before timeout, got %v", err)
|
||||
}
|
||||
|
||||
err := runWithTimeout(10*time.Millisecond, func() error {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return nil
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "timed out") {
|
||||
t.Fatalf("expected timeout error, got %v", err)
|
||||
}
|
||||
|
||||
expected := errors.New("boom")
|
||||
err = runWithTimeout(20*time.Millisecond, func() error { return expected })
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatalf("expected function error propagation, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationAPIs_InvalidSourcePath(t *testing.T) {
|
||||
cfg := MigrationConfig{Path: "definitely-missing-migrations-dir"}
|
||||
|
||||
if err := RunMigrationsUp("postgres://ignored", cfg); err == nil {
|
||||
t.Fatalf("expected RunMigrationsUp error for missing source path")
|
||||
}
|
||||
if err := RunMigrationsDown("postgres://ignored", cfg, 1); err == nil {
|
||||
t.Fatalf("expected RunMigrationsDown error for missing source path")
|
||||
}
|
||||
if err := ForceMigrationVersion("postgres://ignored", cfg, 1); err == nil {
|
||||
t.Fatalf("expected ForceMigrationVersion error for missing source path")
|
||||
}
|
||||
if _, err := GetMigrationStatus("postgres://ignored", cfg); err == nil {
|
||||
t.Fatalf("expected GetMigrationStatus error for missing source path")
|
||||
}
|
||||
|
||||
closeMigrator(nil)
|
||||
}
|
||||
Reference in New Issue
Block a user