151 lines
3.3 KiB
Go
151 lines
3.3 KiB
Go
package migrate
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"github.com/golang-migrate/migrate/v4"
|
|
_ "github.com/golang-migrate/migrate/v4/database/postgres"
|
|
_ "github.com/golang-migrate/migrate/v4/source/file"
|
|
)
|
|
|
|
const (
|
|
defaultMigrationsPath = "file://migrations"
|
|
)
|
|
|
|
type MigrationConfig struct {
|
|
Path string
|
|
LockTimeout time.Duration
|
|
StartupTimeout time.Duration
|
|
}
|
|
|
|
type MigrationStatus struct {
|
|
Version uint
|
|
Dirty bool
|
|
}
|
|
|
|
func RunMigrationsUp(databaseURL string, cfg MigrationConfig) error {
|
|
exec := func() error {
|
|
m, err := newMigrator(databaseURL, cfg.Path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer closeMigrator(m)
|
|
|
|
if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
|
return fmt.Errorf("run migrations up: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
return runWithTimeout(cfg.StartupTimeout, exec)
|
|
}
|
|
|
|
func RunMigrationsDown(databaseURL string, cfg MigrationConfig, steps int) error {
|
|
m, err := newMigrator(databaseURL, cfg.Path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer closeMigrator(m)
|
|
|
|
if steps > 0 {
|
|
if err := m.Steps(-steps); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
|
return fmt.Errorf("run migrations down %d steps: %w", steps, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if err := m.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
|
return fmt.Errorf("run migrations down all: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ForceMigrationVersion(databaseURL string, cfg MigrationConfig, version int) error {
|
|
m, err := newMigrator(databaseURL, cfg.Path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer closeMigrator(m)
|
|
|
|
if err := m.Force(version); err != nil {
|
|
return fmt.Errorf("force migration version %d: %w", version, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func GetMigrationStatus(databaseURL string, cfg MigrationConfig) (MigrationStatus, error) {
|
|
m, err := newMigrator(databaseURL, cfg.Path)
|
|
if err != nil {
|
|
return MigrationStatus{}, err
|
|
}
|
|
defer closeMigrator(m)
|
|
|
|
version, dirty, err := m.Version()
|
|
if errors.Is(err, migrate.ErrNilVersion) {
|
|
return MigrationStatus{}, nil
|
|
}
|
|
if err != nil {
|
|
return MigrationStatus{}, fmt.Errorf("read migration version: %w", err)
|
|
}
|
|
|
|
return MigrationStatus{Version: version, Dirty: dirty}, nil
|
|
}
|
|
|
|
func ResolveMigrationsPath(pathValue string) string {
|
|
if pathValue == "" {
|
|
pathValue = defaultMigrationsPath
|
|
}
|
|
if filepath.IsAbs(pathValue) {
|
|
return "file://" + pathValue
|
|
}
|
|
if len(pathValue) >= len("file://") && pathValue[:len("file://")] == "file://" {
|
|
return pathValue
|
|
}
|
|
wd, err := os.Getwd()
|
|
if err != nil {
|
|
return defaultMigrationsPath
|
|
}
|
|
return "file://" + filepath.Join(wd, pathValue)
|
|
}
|
|
|
|
func newMigrator(databaseURL string, path string) (*migrate.Migrate, error) {
|
|
migrationsPath := ResolveMigrationsPath(path)
|
|
m, err := migrate.New(migrationsPath, databaseURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open migration driver (%s): %w", migrationsPath, err)
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
func closeMigrator(m *migrate.Migrate) {
|
|
if m == nil {
|
|
return
|
|
}
|
|
sourceErr, dbErr := m.Close()
|
|
if sourceErr != nil || dbErr != nil {
|
|
_ = sourceErr
|
|
_ = dbErr
|
|
}
|
|
}
|
|
|
|
func runWithTimeout(timeout time.Duration, fn func() error) error {
|
|
if timeout <= 0 {
|
|
return fn()
|
|
}
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
done <- fn()
|
|
}()
|
|
|
|
select {
|
|
case err := <-done:
|
|
return err
|
|
case <-time.After(timeout):
|
|
return fmt.Errorf("migration timed out after %s", timeout)
|
|
}
|
|
}
|