Skip to content

Commit

Permalink
Add restore validation
Browse files Browse the repository at this point in the history
  • Loading branch information
dmashuda committed Feb 6, 2025
1 parent 2bc4184 commit f440629
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 33 deletions.
8 changes: 7 additions & 1 deletion internal/dev_server/api/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@ paths:
post:
summary: post backup
operationId: restoreBackup
requestBody:
content:
application/vnd.sqlite3:
schema:
type: string
format: binary
responses:
200:
$ref: "#/components/responses/DbBackup"
description: 'Backup restored'
/dev/projects:
get:
summary: lists all projects that have been configured for the dev server
Expand Down
6 changes: 4 additions & 2 deletions internal/dev_server/api/restore_backup.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package api

import "context"
import (
"context"
)

func (s server) RestoreBackup(ctx context.Context, request RestoreBackupRequestObject) (RestoreBackupResponseObject, error) {
//TODO implement me
request.Body

Check failure on line 8 in internal/dev_server/api/restore_backup.go

View workflow job for this annotation

GitHub Actions / build (1.22.5)

request.Body (variable of type io.Reader) is not used

Check failure on line 8 in internal/dev_server/api/restore_backup.go

View workflow job for this annotation

GitHub Actions / build (1.21.12)

request.Body (variable of type io.Reader) is not used
panic("implement me")
}
19 changes: 6 additions & 13 deletions internal/dev_server/api/server.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion internal/dev_server/db/backup/sqlite_backup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestDbBackup(t *testing.T) {
require.NoError(t, err)
require.Len(t, originalResults, dataSize)

manager := backup.NewManager(dbPath, "main", "ld_cli_backup_test*.bak")
manager := backup.NewManager(dbPath, "main", "ld_cli_backup_test*.bak", "ld_cli_restore_test*.bak")

for i := 0; i < 5; i++ {
t.Run("Backup_"+strconv.Itoa(i), func(t *testing.T) {
Expand Down
75 changes: 62 additions & 13 deletions internal/dev_server/db/backup/sqllite_backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
sqllite "github.com/mattn/go-sqlite3"
"github.com/pkg/errors"
"io"
"log"
"os"
"sync"
Expand All @@ -15,23 +16,27 @@ import (
var c atomic.Int32

type Manager struct {
dbPath string
dbName string
backupFilePattern string
driverName string
mutex sync.Mutex
conns []*sqllite.SQLiteConn
dbPath string
dbName string
backupFilePattern string
restoreFilePattern string
driverName string
validationQueries []string
mutex sync.Mutex
conns []*sqllite.SQLiteConn
}

func NewManager(dbPath string, dbName string, backupFilePattern string) *Manager {
func NewManager(dbPath string, dbName string, backupFilePattern string, restoreFilePattern string) *Manager {
count := c.Add(1)
m := &Manager{
dbPath: dbPath,
dbName: dbName,
backupFilePattern: backupFilePattern,
driverName: fmt.Sprintf("sqlite3-backups-%d", count),
conns: make([]*sqllite.SQLiteConn, 0),
mutex: sync.Mutex{},
dbPath: dbPath,
dbName: dbName,
backupFilePattern: backupFilePattern,
restoreFilePattern: restoreFilePattern,
driverName: fmt.Sprintf("sqlite3-backups-%d", count),
conns: make([]*sqllite.SQLiteConn, 0),
validationQueries: make([]string, 0),
mutex: sync.Mutex{},
}
sql.Register(m.driverName, &sqllite.SQLiteDriver{
ConnectHook: func(conn *sqllite.SQLiteConn) error {
Expand All @@ -42,6 +47,13 @@ func NewManager(dbPath string, dbName string, backupFilePattern string) *Manager
return m
}

// AddValidationQueries Adds queries to run on a restored database to ensure meets some criteria
// These queries should cause db.Exec to return error if the database imported is invalid.
// For example, if the database does not have a vital table
func (m *Manager) AddValidationQueries(queries ...string) {
m.validationQueries = append(m.validationQueries, queries...)
}

func (m *Manager) resetConnections() {
m.conns = make([]*sqllite.SQLiteConn, 0)
}
Expand All @@ -61,6 +73,43 @@ func (m *Manager) connectToDb(ctx context.Context, path string) (*sql.DB, error)
return db, nil
}

// RestoreToFile returns a string path of the sqlite database restored from the stream
func (m *Manager) RestoreToFile(ctx context.Context, stream io.ReadCloser) (string, error) {
m.mutex.Lock()
defer m.mutex.Unlock()

// Make a temp file to copy into
tempFile, err := os.CreateTemp("", m.restoreFilePattern)
if err != nil {
return "", errors.Wrapf(err, "unable to create temp file")
}
_, err = io.Copy(tempFile, stream)
if err != nil {
return "", errors.Wrapf(err, "unable to write to temp file")
}

// connect to db
copiedDb, err := m.connectToDb(ctx, m.dbPath)
if err != nil {
return "", errors.Wrapf(err, "unable to connect to database")
}
defer func(copiedDb *sql.DB) {
err := copiedDb.Close()
if err != nil {
log.Println(err)
}
}(copiedDb)

for _, query := range m.validationQueries {
_, err := copiedDb.ExecContext(ctx, query)
if err != nil {
return "", errors.Wrapf(err, "restored db failed validation query: %s", query)
}
}

return tempFile.Name(), nil
}

// MakeBackupFile returns a string path of the sqlite database backup
func (m *Manager) MakeBackupFile(ctx context.Context) (string, error) {
m.mutex.Lock()
Expand Down
14 changes: 11 additions & 3 deletions internal/dev_server/db/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,9 @@ func (s Sqlite) DeactivateOverride(ctx context.Context, projectKey, flagKey stri
}

func (s Sqlite) RestoreBackup(ctx context.Context, stream io.ReadCloser) (string, error) {
//TODO implement me
panic("implement me")
filepath, err := s.backupManager.RestoreToFile(ctx, stream)

return filepath, err
}

func (s Sqlite) CreateBackup(ctx context.Context) (io.ReadCloser, int64, error) {
Expand All @@ -369,7 +370,8 @@ func (s Sqlite) CreateBackup(ctx context.Context) (io.ReadCloser, int64, error)

func NewSqlite(ctx context.Context, dbPath string) (Sqlite, error) {
store := new(Sqlite)
store.backupManager = backup.NewManager(dbPath, "main", "ld_cli_*.bak")
store.backupManager = backup.NewManager(dbPath, "main", "ld_cli_*.bak", "ld_cli_restore_*.bak")
store.backupManager.AddValidationQueries(validationQueries...)
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return Sqlite{}, err
Expand All @@ -382,6 +384,12 @@ func NewSqlite(ctx context.Context, dbPath string) (Sqlite, error) {
return *store, nil
}

var validationQueries = []string{
"SELECT COUNT(1) from projects",
"SELECT COUNT(1) from overrides",
"SELECT COUNT(1) from available_variations",
}

func (s Sqlite) runMigrations(ctx context.Context) error {
tx, err := s.database.BeginTx(ctx, nil)
if err != nil {
Expand Down

0 comments on commit f440629

Please sign in to comment.