Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backend/director/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,10 @@ def delete(self):
"""Delete the session from the database."""
return self.db.delete_session(self.session_id)

def update(self, **kwargs) -> bool:
"""Update the session in the database."""
return self.db.update_session(self.session_id, **kwargs)

def emit_event(self, event: BaseEvent, namespace="/chat"):
"""Emits a structured WebSocket event to notify all clients about updates."""

Expand Down
22 changes: 22 additions & 0 deletions backend/director/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,28 @@ def add_or_update_context_msg(
"""Update context messages for a session."""
pass

@abstractmethod
def update_session(self, session_id: str, **kwargs) -> bool:
"""Update a session in the database."""
pass

@abstractmethod
def delete_session(self, session_id: str) -> tuple[bool, list]:
"""Delete a session from the database.
:return: (success, failed_components)
"""
pass

@abstractmethod
def make_session_public(self, session_id: str, is_public: bool) -> bool:
"""Make a session public or private."""
pass

@abstractmethod
def get_public_session(self, session_id: str) -> dict:
"""Get a public session by session_id."""
pass

@abstractmethod
def health_check(self) -> bool:
"""Check if the database is healthy."""
Expand Down
93 changes: 89 additions & 4 deletions backend/director/db/postgres/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ def __init__(self):
port=os.getenv("POSTGRES_PORT", "5432"),
)
self.cursor = self.conn.cursor(cursor_factory=RealDictCursor)

initialize_postgres()
def create_session(
self,
session_id: str,
video_id: str,
collection_id: str,
name: str = None,
created_at: int = None,
updated_at: int = None,
metadata: dict = {},
Expand All @@ -47,14 +48,15 @@ def create_session(

self.cursor.execute(
"""
INSERT INTO sessions (session_id, video_id, collection_id, created_at, updated_at, metadata)
VALUES (%s, %s, %s, %s, %s, %s)
INSERT INTO sessions (session_id, video_id, collection_id, name, created_at, updated_at, metadata)
VALUES (%s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (session_id) DO NOTHING
""",
(
session_id,
video_id,
collection_id,
name,
created_at,
updated_at,
json.dumps(metadata),
Expand Down Expand Up @@ -195,7 +197,7 @@ def delete_context(self, session_id: str) -> bool:
self.conn.commit()
return self.cursor.rowcount > 0

def delete_session(self, session_id: str) -> bool:
def delete_session(self, session_id: str) -> tuple[bool, list]:
failed_components = []
if not self.delete_conversation(session_id):
failed_components.append("conversation")
Expand All @@ -210,6 +212,89 @@ def delete_session(self, session_id: str) -> bool:
success = len(failed_components) < 3
return success, failed_components

def update_session(self, session_id: str, **kwargs) -> bool:
"""Update a session in the database."""
try:
if not kwargs:
return False

allowed_fields = {"name", "video_id", "collection_id", "metadata"}
update_fields = []
update_values = []

for key, value in kwargs.items():
if key not in allowed_fields:
continue
if key == "metadata" and not isinstance(value, str):
value = json.dumps(value)
update_fields.append(f"{key} = %s")
update_values.append(value)

if not update_fields:
return False

update_fields.append("updated_at = %s")
update_values.append(int(time.time()))

update_values.extend([session_id])

query = f"""
UPDATE sessions
SET {', '.join(update_fields)}
WHERE session_id = %s
"""

self.cursor.execute(query, update_values)
self.conn.commit()
return self.cursor.rowcount > 0

except Exception:
logger.exception(f"Error updating session {session_id}")
return False

def make_session_public(self, session_id: str, is_public: bool) -> bool:
"""Make a session public or private."""
try:
query = """
UPDATE sessions
SET is_public = %s, updated_at = %s
WHERE session_id = %s
"""
current_time = int(time.time())
self.cursor.execute(query, (is_public, current_time, session_id))
self.conn.commit()
return self.cursor.rowcount > 0
except Exception as e:
logger.exception(f"Error making session public/private: {e}")
return False

def get_public_session(self, session_id: str) -> dict:
"""Get a public session by session_id."""
try:
query = """
SELECT session_id, video_id, collection_id, name, created_at, updated_at, metadata, is_public
FROM sessions
WHERE session_id = %s AND is_public = TRUE
"""
self.cursor.execute(query, (session_id,))
row = self.cursor.fetchone()
if row:
session = {
"session_id": row["session_id"],
"video_id": row["video_id"],
"collection_id": row["collection_id"],
"name": row["name"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"metadata": row["metadata"] if row["metadata"] else {},
"is_public": row["is_public"]
}
return session
return {}
except Exception as e:
logger.exception(f"Error getting public session: {e}")
return {}

def health_check(self) -> bool:
try:
query = """
Expand Down
4 changes: 4 additions & 0 deletions backend/director/db/postgres/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
session_id TEXT PRIMARY KEY,
video_id TEXT,
collection_id TEXT,
name TEXT,
is_public BOOLEAN DEFAULT FALSE,
created_at BIGINT,
updated_at BIGINT,
metadata JSONB
Expand Down Expand Up @@ -68,6 +70,8 @@ def initialize_postgres():
cursor.execute(CREATE_SESSIONS_TABLE)
cursor.execute(CREATE_CONVERSATIONS_TABLE)
cursor.execute(CREATE_CONTEXT_MESSAGES_TABLE)
cursor.execute("ALTER TABLE sessions ADD COLUMN IF NOT EXISTS name TEXT")
cursor.execute("ALTER TABLE sessions ADD COLUMN IF NOT EXISTS is_public BOOLEAN DEFAULT FALSE")
conn.commit()
logger.info("PostgreSQL tables created successfully")
except Exception as e:
Expand Down
100 changes: 95 additions & 5 deletions backend/director/db/sqlite/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self, db_path: str = None):
self.db_path = os.getenv("SQLITE_DB_PATH", "director.db")
else:
self.db_path = db_path
initialize_sqlite(self.db_path)
self.conn = sqlite3.connect(self.db_path, check_same_thread=True)
self.conn.row_factory = sqlite3.Row
self.cursor = self.conn.cursor()
Expand All @@ -33,6 +34,7 @@ def create_session(
session_id: str,
video_id: str,
collection_id: str,
name: str = None,
created_at: int = None,
updated_at: int = None,
metadata: dict = {},
Expand All @@ -52,13 +54,14 @@ def create_session(

self.cursor.execute(
"""
INSERT OR IGNORE INTO sessions (session_id, video_id, collection_id, created_at, updated_at, metadata)
VALUES (?, ?, ?, ?, ?, ?)
INSERT OR IGNORE INTO sessions (session_id, video_id, collection_id, name, created_at, updated_at, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
session_id,
video_id,
collection_id,
name,
created_at,
updated_at,
json.dumps(metadata),
Expand Down Expand Up @@ -153,8 +156,7 @@ def add_or_update_msg_to_conv(

def get_conversations(self, session_id: str) -> list:
self.cursor.execute(
"SELECT * FROM conversations WHERE session_id = ? ORDER BY created_at ASC",
(session_id,),
"SELECT * FROM conversations WHERE session_id = ?", (session_id,)
)
rows = self.cursor.fetchall()
conversations = []
Expand Down Expand Up @@ -241,7 +243,7 @@ def delete_context(self, session_id: str) -> bool:
self.conn.commit()
return self.cursor.rowcount > 0

def delete_session(self, session_id: str) -> bool:
def delete_session(self, session_id: str) -> tuple[bool, list]:
"""Delete a session and all its associated data.

:param str session_id: Unique session ID.
Expand All @@ -259,6 +261,94 @@ def delete_session(self, session_id: str) -> bool:
success = len(failed_components) < 3
return success, failed_components

def update_session(self, session_id: str, **kwargs) -> bool:
"""Update a session in the database.

:param session_id: Unique session ID.
:param kwargs: Fields to update.
:return: True if update was successful, False otherwise.
"""
try:
if not kwargs:
return False

allowed_fields = {"name", "video_id", "collection_id", "metadata"}
update_fields = []
update_values = []

for key, value in kwargs.items():
if key not in allowed_fields:
continue
if key == "metadata" and not isinstance(value, str):
value = json.dumps(value)
update_fields.append(f"{key} = ?")
update_values.append(value)

if not update_fields:
return False

update_fields.append("updated_at = ?")
update_values.append(int(time.time()))

update_values.extend([session_id])

query = f"""
UPDATE sessions
SET {', '.join(update_fields)}
WHERE session_id = ?
"""

self.cursor.execute(query, update_values)
self.conn.commit()
return self.cursor.rowcount > 0

except Exception:
logger.exception(f"Error updating session {session_id}")
return False

def make_session_public(self, session_id: str, is_public: bool) -> bool:
"""Make a session public or private."""
try:
query = """
UPDATE sessions
SET is_public = ?, updated_at = ?
WHERE session_id = ?
"""
current_time = int(time.time())
self.cursor.execute(query, (is_public, current_time, session_id))
self.conn.commit()
return self.cursor.rowcount > 0
except Exception:
logger.exception("Error making session public/private")
return False

def get_public_session(self, session_id: str) -> dict:
"""Get a public session by session_id."""
try:
query = """
SELECT session_id, video_id, collection_id, name, created_at, updated_at, metadata, is_public
FROM sessions
WHERE session_id = ? AND is_public = TRUE
"""
self.cursor.execute(query, (session_id,))
row = self.cursor.fetchone()
if row:
session = {
"session_id": row[0],
"video_id": row[1],
"collection_id": row[2],
"name": row[3],
"created_at": row[4],
"updated_at": row[5],
"metadata": json.loads(row[6]) if row[6] else {},
"is_public": row[7]
}
return session
return {}
except Exception as e:
logger.exception(f"Error getting public session: {e}")
return {}

def health_check(self) -> bool:
"""Check if the SQLite database is healthy and the necessary tables exist. If not, create them."""
try:
Expand Down
12 changes: 12 additions & 0 deletions backend/director/db/sqlite/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
session_id TEXT PRIMARY KEY,
video_id TEXT,
collection_id TEXT,
name TEXT,
is_public BOOLEAN DEFAULT FALSE,
created_at INTEGER,
updated_at INTEGER,
metadata JSON
Expand Down Expand Up @@ -49,10 +51,20 @@ def initialize_sqlite(db_name="director.db"):
conn = sqlite3.connect(db_name)
cursor = conn.cursor()

# Create base tables
cursor.execute(CREATE_SESSIONS_TABLE)
cursor.execute(CREATE_CONVERSATIONS_TABLE)
cursor.execute(CREATE_CONTEXT_MESSAGES_TABLE)

cursor.execute("PRAGMA table_info(sessions)")
columns = [col[1] for col in cursor.fetchall()]

if "name" not in columns:
cursor.execute("ALTER TABLE sessions ADD COLUMN name TEXT")

if "is_public" not in columns:
cursor.execute("ALTER TABLE sessions ADD COLUMN is_public BOOLEAN DEFAULT FALSE")

conn.commit()
conn.close()

Expand Down
Loading