Skip to content

Commit

Permalink
feat: add support for using postgres for the backend DB
Browse files Browse the repository at this point in the history
  • Loading branch information
cheahjs committed Apr 27, 2024
1 parent f8f9f27 commit e91a49c
Show file tree
Hide file tree
Showing 15 changed files with 329 additions and 18 deletions.
13 changes: 12 additions & 1 deletion backend/apps/litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from config import SRC_LOG_LEVELS, ENV
from constants import MESSAGES

import os

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])

Expand Down Expand Up @@ -62,6 +64,13 @@
# Global variable to store the subprocess reference
background_process = None

CONFLICT_ENV_VARS = [
# Uvicorn uses PORT, so LiteLLM might use it as well
"PORT",
# LiteLLM uses DATABASE_URL for Prisma connections
"DATABASE_URL",
]


async def run_background_process(command):
global background_process
Expand All @@ -70,9 +79,11 @@ async def run_background_process(command):
try:
# Log the command to be executed
log.info(f"Executing command: {command}")
# Filter environment variables known to conflict with litellm
env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
# Execute the command and create a subprocess
process = await asyncio.create_subprocess_exec(
*command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
*command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
)
background_process = process
log.info("Subprocess started successfully.")
Expand Down
9 changes: 5 additions & 4 deletions backend/apps/web/internal/db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from peewee import *
from peewee_migrate import Router
from config import SRC_LOG_LEVELS, DATA_DIR
from playhouse.db_url import connect
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL
import os
import logging

Expand All @@ -11,12 +12,12 @@
if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
log.info("File renamed successfully.")
log.info("Database migrated from Ollama-WebUI successfully.")
else:
pass


DB = SqliteDatabase(f"{DATA_DIR}/webui.db")
DB = connect(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log)
router.run()
DB.connect(reuse_if_open=True)
105 changes: 105 additions & 0 deletions backend/apps/web/internal/migrations/001_initial_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""

# We perform different migrations for SQLite and other databases
# This is because SQLite is very loose with enforcing its schema, and trying to migrate other databases like SQLite
# will require per-database SQL queries.
# Instead, we assume that because external DB support was added at a later date, it is safe to assume a newer base
# schema instead of trying to migrate from an older schema.
if isinstance(database, pw.SqliteDatabase):
migrate_sqlite(migrator, database, fake=fake)
else:
migrate_external(migrator, database, fake=fake)


def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model
class Auth(pw.Model):
id = pw.CharField(max_length=255, unique=True)
Expand Down Expand Up @@ -129,6 +141,99 @@ class Meta:
table_name = "user"


def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model
class Auth(pw.Model):
id = pw.CharField(max_length=255, unique=True)
email = pw.CharField(max_length=255)
password = pw.TextField()
active = pw.BooleanField()

class Meta:
table_name = "auth"

@migrator.create_model
class Chat(pw.Model):
id = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.TextField()
chat = pw.TextField()
timestamp = pw.BigIntegerField()

class Meta:
table_name = "chat"

@migrator.create_model
class ChatIdTag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
tag_name = pw.CharField(max_length=255)
chat_id = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()

class Meta:
table_name = "chatidtag"

@migrator.create_model
class Document(pw.Model):
id = pw.AutoField()
collection_name = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255, unique=True)
title = pw.TextField()
filename = pw.TextField()
content = pw.TextField(null=True)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()

class Meta:
table_name = "document"

@migrator.create_model
class Modelfile(pw.Model):
id = pw.AutoField()
tag_name = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
modelfile = pw.TextField()
timestamp = pw.BigIntegerField()

class Meta:
table_name = "modelfile"

@migrator.create_model
class Prompt(pw.Model):
id = pw.AutoField()
command = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.TextField()
content = pw.TextField()
timestamp = pw.BigIntegerField()

class Meta:
table_name = "prompt"

@migrator.create_model
class Tag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
data = pw.TextField(null=True)

class Meta:
table_name = "tag"

@migrator.create_model
class User(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
email = pw.CharField(max_length=255)
role = pw.CharField(max_length=255)
profile_image_url = pw.TextField()
timestamp = pw.BigIntegerField()

class Meta:
table_name = "user"


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""

Expand Down
53 changes: 53 additions & 0 deletions backend/apps/web/internal/migrations/005_add_updated_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""

if isinstance(database, pw.SqliteDatabase):
migrate_sqlite(migrator, database, fake=fake)
else:
migrate_external(migrator, database, fake=fake)


def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields(
"chat",
Expand All @@ -60,9 +67,40 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
)


def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields(
"chat",
created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
)

# Populate the new fields from an existing 'timestamp' field
migrator.sql(
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
)

# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("chat", "timestamp")

# Update the fields to be not null now that they are populated
migrator.change_fields(
"chat",
created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False),
)


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""

if isinstance(database, pw.SqliteDatabase):
rollback_sqlite(migrator, database, fake=fake)
else:
rollback_external(migrator, database, fake=fake)


def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))

Expand All @@ -75,3 +113,18 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):

# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))


def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))

# Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp
migrator.sql("UPDATE chat SET timestamp = created_at")

# Remove the created_at and updated_at fields
migrator.remove_fields("chat", "created_at", "updated_at")

# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Peewee migrations -- 006_migrate_timestamps_and_charfields.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""

# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"document",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"modelfile",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"prompt",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"user",
timestamp=pw.BigIntegerField(),
)
# Alter the tables with varchar to text where necessary
migrator.change_fields(
"auth",
password=pw.TextField(),
)
migrator.change_fields(
"chat",
title=pw.TextField(),
)
migrator.change_fields(
"document",
title=pw.TextField(),
filename=pw.TextField(),
)
migrator.change_fields(
"prompt",
title=pw.TextField(),
)
migrator.change_fields(
"user",
profile_image_url=pw.TextField(),
)


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""

if isinstance(database, pw.SqliteDatabase):
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
timestamp=pw.DateField(),
)
migrator.change_fields(
"document",
timestamp=pw.DateField(),
)
migrator.change_fields(
"modelfile",
timestamp=pw.DateField(),
)
migrator.change_fields(
"prompt",
timestamp=pw.DateField(),
)
migrator.change_fields(
"user",
timestamp=pw.DateField(),
)
migrator.change_fields(
"auth",
password=pw.CharField(max_length=255),
)
migrator.change_fields(
"chat",
title=pw.CharField(),
)
migrator.change_fields(
"document",
title=pw.CharField(),
filename=pw.CharField(),
)
migrator.change_fields(
"prompt",
title=pw.CharField(),
)
migrator.change_fields(
"user",
profile_image_url=pw.CharField(),
)
2 changes: 1 addition & 1 deletion backend/apps/web/models/auths.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class Auth(Model):
id = CharField(unique=True)
email = CharField()
password = CharField()
password = TextField()
active = BooleanField()

class Meta:
Expand Down
Loading

0 comments on commit e91a49c

Please sign in to comment.