// Package migrate applies, rolls back, forces and inspects database schema
// migrations using github.com/golang-migrate/migrate/v4.
package migrate

import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/golang-migrate/migrate/v4"
	// Blank imports register the postgres database driver and the file://
	// source driver with golang-migrate.
	_ "github.com/golang-migrate/migrate/v4/database/postgres"
	_ "github.com/golang-migrate/migrate/v4/source/file"
)

const (
	// defaultMigrationsPath is the source URL used when no path is configured.
	defaultMigrationsPath = "file://migrations"
)

// MigrationConfig controls where migrations are read from and how long the
// operations in this package may wait.
type MigrationConfig struct {
	// Path is the migrations source; see ResolveMigrationsPath for the
	// accepted forms. Empty selects defaultMigrationsPath.
	Path string
	// LockTimeout, when positive, overrides golang-migrate's
	// Migrate.LockTimeout (how long to wait for the migration lock).
	// Zero or negative keeps the library default.
	LockTimeout time.Duration
	// StartupTimeout bounds RunMigrationsUp. Zero or negative disables the
	// timeout.
	StartupTimeout time.Duration
}

// MigrationStatus is the migration version recorded in the database.
// The zero value is returned when no version has been recorded yet.
type MigrationStatus struct {
	Version uint
	Dirty   bool
}

// RunMigrationsUp applies all pending up migrations; migrate.ErrNoChange
// (nothing to apply) is treated as success. A failure to close the migrator
// is reported when the migration itself succeeded.
//
// If cfg.StartupTimeout is positive and the run has not finished in time, a
// timeout error is returned and a graceful stop is requested through
// Migrate.GracefulStop so that golang-migrate stops at its next safe point.
// A migration already in progress may still complete in the background.
func RunMigrationsUp(databaseURL string, cfg MigrationConfig) error {
	return runWithTimeoutAbort(cfg.StartupTimeout, func(abort <-chan struct{}) (err error) {
		m, err := openMigrator(databaseURL, cfg)
		if err != nil {
			return err
		}
		defer closeMigratorInto(m, &err)

		// Forward a timeout to golang-migrate while Up is running; the
		// watcher exits as soon as this function returns.
		finished := make(chan struct{})
		defer close(finished)
		go stopOnAbort(m, abort, finished)

		if upErr := m.Up(); upErr != nil && !errors.Is(upErr, migrate.ErrNoChange) {
			return fmt.Errorf("run migrations up: %w", upErr)
		}
		return nil
	})
}

// RunMigrationsDown rolls back migrations. When steps > 0 it rolls back that
// many migrations; any other value (zero or negative) rolls back all of them.
// migrate.ErrNoChange is treated as success. A failure to close the migrator
// is reported when the rollback itself succeeded.
func RunMigrationsDown(databaseURL string, cfg MigrationConfig, steps int) (err error) {
	m, err := openMigrator(databaseURL, cfg)
	if err != nil {
		return err
	}
	defer closeMigratorInto(m, &err)

	if steps > 0 {
		if stepErr := m.Steps(-steps); stepErr != nil && !errors.Is(stepErr, migrate.ErrNoChange) {
			return fmt.Errorf("run migrations down %d steps: %w", steps, stepErr)
		}
		return nil
	}
	if downErr := m.Down(); downErr != nil && !errors.Is(downErr, migrate.ErrNoChange) {
		return fmt.Errorf("run migrations down all: %w", downErr)
	}
	return nil
}

// ForceMigrationVersion records version as the current migration version via
// golang-migrate's Migrate.Force (which, per its documentation, does not run
// any migration and clears the dirty flag). A failure to close the migrator
// is reported when Force itself succeeded.
func ForceMigrationVersion(databaseURL string, cfg MigrationConfig, version int) (err error) {
	m, err := openMigrator(databaseURL, cfg)
	if err != nil {
		return err
	}
	defer closeMigratorInto(m, &err)

	if forceErr := m.Force(version); forceErr != nil {
		return fmt.Errorf("force migration version %d: %w", version, forceErr)
	}
	return nil
}

// GetMigrationStatus reports the migration version and dirty flag recorded in
// the database. If no version is recorded (migrate.ErrNilVersion) it returns
// the zero MigrationStatus and a nil error. On any error the returned status
// is the zero value.
func GetMigrationStatus(databaseURL string, cfg MigrationConfig) (status MigrationStatus, err error) {
	m, err := openMigrator(databaseURL, cfg)
	if err != nil {
		return MigrationStatus{}, err
	}
	defer func() {
		closeMigratorInto(m, &err)
		if err != nil {
			// Never hand back a status alongside an error.
			status = MigrationStatus{}
		}
	}()

	version, dirty, verErr := m.Version()
	if errors.Is(verErr, migrate.ErrNilVersion) {
		return MigrationStatus{}, nil
	}
	if verErr != nil {
		return MigrationStatus{}, fmt.Errorf("read migration version: %w", verErr)
	}
	return MigrationStatus{Version: version, Dirty: dirty}, nil
}

// ResolveMigrationsPath turns a configured migrations location into a
// golang-migrate source URL:
//   - "" returns defaultMigrationsPath;
//   - a value already starting with "file://" is returned unchanged;
//   - an absolute filesystem path gets a "file://" prefix;
//   - a relative path is joined with the working directory and prefixed.
//
// If the working directory cannot be determined, the relative path is kept
// as-is (prefixed) rather than being replaced by the default location.
func ResolveMigrationsPath(pathValue string) string {
	const scheme = "file://"

	switch {
	case pathValue == "":
		return defaultMigrationsPath
	case strings.HasPrefix(pathValue, scheme):
		return pathValue
	case filepath.IsAbs(pathValue):
		return scheme + pathValue
	}

	wd, err := os.Getwd()
	if err != nil {
		// Preserve the caller's intent instead of silently pointing at a
		// different directory.
		return scheme + pathValue
	}
	return scheme + filepath.Join(wd, pathValue)
}

// newMigrator opens a golang-migrate instance for the resolved migrations
// path and databaseURL. The error names the resolved source URL.
func newMigrator(databaseURL string, path string) (*migrate.Migrate, error) {
	migrationsPath := ResolveMigrationsPath(path)
	m, err := migrate.New(migrationsPath, databaseURL)
	if err != nil {
		return nil, fmt.Errorf("open migration driver (%s): %w", migrationsPath, err)
	}
	return m, nil
}

// openMigrator is newMigrator plus the remaining MigrationConfig settings:
// a positive cfg.LockTimeout is applied to the returned instance.
func openMigrator(databaseURL string, cfg MigrationConfig) (*migrate.Migrate, error) {
	m, err := newMigrator(databaseURL, cfg.Path)
	if err != nil {
		return nil, err
	}
	if cfg.LockTimeout > 0 {
		m.LockTimeout = cfg.LockTimeout
	}
	return m, nil
}

// closeMigrator closes both the source and database drivers of m and
// returns a wrapped error describing any failure. A nil m is a no-op.
func closeMigrator(m *migrate.Migrate) error {
	if m == nil {
		return nil
	}
	sourceErr, dbErr := m.Close()
	switch {
	case sourceErr != nil && dbErr != nil:
		return fmt.Errorf("close migration source: %w (database close also failed: %v)", sourceErr, dbErr)
	case sourceErr != nil:
		return fmt.Errorf("close migration source: %w", sourceErr)
	case dbErr != nil:
		return fmt.Errorf("close migration database: %w", dbErr)
	}
	return nil
}

// closeMigratorInto closes m and stores a close failure in *errp, but only if
// *errp is still nil, so the original operation's error takes precedence.
// Intended for use with defer and a named error result.
func closeMigratorInto(m *migrate.Migrate, errp *error) {
	if cerr := closeMigrator(m); cerr != nil && *errp == nil {
		*errp = cerr
	}
}

// stopOnAbort requests a graceful stop of m if abort is closed before
// finished. It returns as soon as either channel is closed.
func stopOnAbort(m *migrate.Migrate, abort, finished <-chan struct{}) {
	select {
	case <-abort:
		// Non-blocking send so this goroutine can never hang, even if a stop
		// request is already pending.
		select {
		case m.GracefulStop <- true:
		default:
		}
	case <-finished:
	}
}

// runWithTimeout runs fn and returns its error, or a timeout error if fn has
// not returned within timeout. A non-positive timeout runs fn synchronously.
// On timeout fn keeps running; its result is dropped into a buffered channel
// so the goroutine can still exit.
func runWithTimeout(timeout time.Duration, fn func() error) error {
	return runWithTimeoutAbort(timeout, func(<-chan struct{}) error { return fn() })
}

// runWithTimeoutAbort behaves like runWithTimeout but passes fn an abort
// channel that is closed when the timeout expires, so fn can wind down.
// With a non-positive timeout the channel is never closed.
func runWithTimeoutAbort(timeout time.Duration, fn func(abort <-chan struct{}) error) error {
	abort := make(chan struct{})
	if timeout <= 0 {
		return fn(abort)
	}

	done := make(chan error, 1)
	go func() { done <- fn(abort) }()

	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case err := <-done:
		return err
	case <-timer.C:
		close(abort)
		return fmt.Errorf("migration timed out after %s", timeout)
	}
}