From 6f1f48342e36fa00a3e89d73695922c47aa94987 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Thu, 30 Jan 2025 14:29:59 -0600 Subject: [PATCH] pkg/sqlutil/sqltest: add CreateOrReplace (#1018) --- pkg/sqlutil/sqltest/sqltest.go | 66 ++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/pkg/sqlutil/sqltest/sqltest.go b/pkg/sqlutil/sqltest/sqltest.go index dea5eb368..36e7c25fe 100644 --- a/pkg/sqlutil/sqltest/sqltest.go +++ b/pkg/sqlutil/sqltest/sqltest.go @@ -1,7 +1,12 @@ package sqltest import ( + "database/sql" + "errors" + "fmt" + "net/url" "os" + "strings" "testing" "github.com/google/uuid" @@ -47,3 +52,64 @@ func SkipInMemory(t *testing.T) { t.Skip("skipping test due to in-memory db") } } + +// CreateOrReplace creates a new database with the given name (optionally from template), and schedules it to be dropped +// after test completion. +func CreateOrReplace(t testing.TB, u url.URL, dbName string, template string) url.URL { + if u.Path == "" { + t.Fatal("path missing from database URL") + } + + if l := len(dbName); l > 63 { + t.Fatalf("dbName %v too long (%d), max is 63 bytes", dbName, l) + } + // Cannot drop test database if we are connected to it, so we must connect + // to a different one. 'postgres' should be present on all postgres installations + u.Path = "/postgres" + db, err := sql.Open(pg.DriverPostgres, u.String()) + if err != nil { + t.Fatalf("in order to drop the test database, we need to connect to a separate database"+ + " called 'postgres'. But we are unable to open 'postgres' database: %+v\n", err) + } + defer db.Close() + + _, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)) + if err != nil { + t.Fatalf("unable to drop postgres migrations test database: %v", err) + } + if template != "" { + _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s WITH TEMPLATE %s", dbName, template)) + } else { + _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)) + } + if err != nil { + t.Fatalf("unable to create postgres test database with name '%s': %v", dbName, err) + } + u.Path = fmt.Sprintf("/%s", dbName) + t.Cleanup(func() { assert.NoError(t, drop(u)) }) + return u +} + +// drop drops the database at the given URL. +func drop(dbURL url.URL) error { + if dbURL.Path == "" { + return errors.New("path missing from database URL") + } + dbname := strings.TrimPrefix(dbURL.Path, "/") + + // Cannot drop test database if we are connected to it, so we must connect + // to a different one. 'postgres' should be present on all postgres installations + dbURL.Path = "/postgres" + db, err := sql.Open(pg.DriverPostgres, dbURL.String()) + if err != nil { + return fmt.Errorf("in order to drop the test database, we need to connect to a separate database"+ + " called 'postgres'. But we are unable to open 'postgres' database: %+v\n", err) + } + defer db.Close() + + _, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + if err != nil { + return fmt.Errorf("unable to drop postgres migrations test database: %v", err) + } + return nil +}