Skip to content

Commit 4ab5908

Browse files
author
Murat Kaan Meral
committed
code improvements for styling and validation
1 parent 1490365 commit 4ab5908

File tree

4 files changed

+198
-83
lines changed

4 files changed

+198
-83
lines changed

src/strands_agents_builder/strands.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from strands_agents_builder.utils import model_utils
1717
from strands_agents_builder.utils.kb_utils import load_system_prompt, store_conversation_in_kb
1818
from strands_agents_builder.utils.session_utils import (
19+
console,
1920
display_agent_history,
2021
handle_session_commands,
2122
list_sessions_command,
@@ -50,7 +51,7 @@ def handle_shell_command(agent: Agent, command: str, user_input: str) -> None:
5051
non_interactive_mode=True,
5152
)
5253
except Exception as e:
53-
print(f"Error: {str(e)}")
54+
console.print(f"[red]Error: {str(e)}[/red]")
5455

5556

5657
def execute_interactive_mode(
@@ -120,7 +121,7 @@ def execute_interactive_mode(
120121
break
121122
except Exception as e:
122123
callback_handler(force_stop=True) # Stop spinners
123-
print(f"Error: {str(e)}")
124+
console.print(f"[red]Error: {str(e)}[/red]")
124125

125126

126127
def main():

src/strands_agents_builder/utils/session_utils.py

Lines changed: 103 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import datetime
7+
import logging
78
import time
89
import uuid
910
from pathlib import Path
@@ -19,6 +20,42 @@
1920
# Create console for rich formatting
2021
console = Console()
2122

23+
# Constants
24+
SESSION_PREFIX = "session_"
25+
DEFAULT_DISPLAY_LIMIT = 10
26+
27+
# Set up logging
28+
logger = logging.getLogger(__name__)
29+
30+
31+
def validate_session_id(session_id: str) -> bool:
32+
"""Validate that a session ID is safe to use as a directory name."""
33+
if not session_id:
34+
return False
35+
36+
# Check for basic safety - no path separators, no hidden files, reasonable length
37+
if any(char in session_id for char in ["/", "\\", "..", "\0"]):
38+
return False
39+
40+
if session_id.startswith(".") or len(session_id) > 255:
41+
return False
42+
43+
return True
44+
45+
46+
def validate_session_path(path: str) -> bool:
47+
"""Validate that a session path is safe to use."""
48+
if not path:
49+
return False
50+
51+
try:
52+
# Try to create a Path object and check if it's absolute or relative
53+
path_obj = Path(path)
54+
# Basic validation - path should be reasonable
55+
return len(str(path_obj)) < 4096 # Reasonable path length limit
56+
except (ValueError, OSError):
57+
return False
58+
2259

2360
def generate_session_id() -> str:
2461
"""Generate a unique session ID based on timestamp and UUID."""
@@ -44,11 +81,14 @@ def create_session_manager(
4481
session_id: Optional[str] = None, base_path: Optional[str] = None
4582
) -> Optional[FileSessionManager]:
4683
"""Create a FileSessionManager with the given or generated session ID. Returns None if no base_path."""
47-
if not base_path:
84+
if not base_path or not validate_session_path(base_path):
4885
return None
4986

5087
if session_id is None:
5188
session_id = generate_session_id()
89+
elif not validate_session_id(session_id):
90+
logger.warning(f"Invalid session ID provided: {session_id}")
91+
return None
5292

5393
# Create the sessions directory since we're actually creating a session manager
5494
sessions_dir = get_sessions_directory(base_path, create=True)
@@ -57,68 +97,81 @@ def create_session_manager(
5797

5898
def list_available_sessions(base_path: Optional[str] = None) -> list[str]:
5999
"""List all available session IDs in the sessions directory."""
60-
if not base_path:
100+
if not base_path or not validate_session_path(base_path):
61101
return []
62102

63103
# Don't create directory, just check if it exists
64104
sessions_dir = get_sessions_directory(base_path, create=False)
65105
session_ids = []
66106

67107
if sessions_dir and sessions_dir.exists():
68-
for session_dir in sessions_dir.iterdir():
69-
if session_dir.is_dir() and session_dir.name.startswith("session_"):
70-
# Extract session ID from directory name (remove "session_" prefix)
71-
session_id = session_dir.name[8:] # len("session_") = 8
72-
session_ids.append(session_id)
108+
try:
109+
for session_dir in sessions_dir.iterdir():
110+
if session_dir.is_dir() and session_dir.name.startswith(SESSION_PREFIX):
111+
# Extract session ID from directory name (remove "session_" prefix)
112+
session_id = session_dir.name[len(SESSION_PREFIX) :]
113+
if validate_session_id(session_id):
114+
session_ids.append(session_id)
115+
except (OSError, PermissionError) as e:
116+
logger.warning(f"Failed to list sessions in {base_path}: {e}")
73117

74118
return sorted(session_ids)
75119

76120

77121
def session_exists(session_id: str, base_path: Optional[str] = None) -> bool:
78122
"""Check if a session exists."""
79-
if not base_path:
123+
if not base_path or not validate_session_path(base_path) or not validate_session_id(session_id):
80124
return False
81125

82126
# Don't create directory, just check if session exists
83127
sessions_dir = get_sessions_directory(base_path, create=False)
84128
if not sessions_dir:
85129
return False
86130

87-
session_dir = sessions_dir / f"session_{session_id}"
131+
session_dir = sessions_dir / f"{SESSION_PREFIX}{session_id}"
88132
return session_dir.exists() and (session_dir / "session.json").exists()
89133

90134

91135
def get_session_info(session_id: str, base_path: Optional[str] = None) -> Optional[dict]:
92136
"""Get basic information about a session."""
93-
if not base_path or not session_exists(session_id, base_path):
137+
if not base_path or not validate_session_path(base_path) or not validate_session_id(session_id):
138+
return None
139+
140+
if not session_exists(session_id, base_path):
94141
return None
95142

96143
# Don't create directory, just get the path
97144
sessions_dir = get_sessions_directory(base_path, create=False)
98145
if not sessions_dir:
99146
return None
100147

101-
session_dir = sessions_dir / f"session_{session_id}"
102-
103-
# Get creation time from directory
104-
created_at = session_dir.stat().st_ctime
148+
session_dir = sessions_dir / f"{SESSION_PREFIX}{session_id}"
105149

106-
# Count messages across all agents
107-
total_messages = 0
108-
agents_dir = session_dir / "agents"
109-
if agents_dir.exists():
110-
for agent_dir in agents_dir.iterdir():
111-
if agent_dir.is_dir():
112-
messages_dir = agent_dir / "messages"
113-
if messages_dir.exists():
114-
total_messages += len([f for f in messages_dir.iterdir() if f.is_file() and f.suffix == ".json"])
115-
116-
return {
117-
"session_id": session_id,
118-
"created_at": created_at,
119-
"total_messages": total_messages,
120-
"path": str(session_dir),
121-
}
150+
try:
151+
# Get creation time from directory
152+
created_at = session_dir.stat().st_ctime
153+
154+
# Count messages across all agents
155+
total_messages = 0
156+
agents_dir = session_dir / "agents"
157+
if agents_dir.exists():
158+
for agent_dir in agents_dir.iterdir():
159+
if agent_dir.is_dir():
160+
messages_dir = agent_dir / "messages"
161+
if messages_dir.exists():
162+
total_messages += len(
163+
[f for f in messages_dir.iterdir() if f.is_file() and f.suffix == ".json"]
164+
)
165+
166+
return {
167+
"session_id": session_id,
168+
"created_at": created_at,
169+
"total_messages": total_messages,
170+
"path": str(session_dir),
171+
}
172+
except (OSError, PermissionError) as e:
173+
logger.warning(f"Failed to get session info for {session_id}: {e}")
174+
return None
122175

123176

124177
def list_sessions_command(session_base_path: Optional[str]) -> None:
@@ -146,8 +199,8 @@ def display_agent_history(agent, session_id: str) -> None:
146199
"""Display conversation history from an agent's loaded messages."""
147200
try:
148201
if agent.messages and len(agent.messages) > 0:
149-
# Display last 10 messages (5 pairs) completely
150-
display_limit = 10
202+
# Display last messages completely
203+
display_limit = DEFAULT_DISPLAY_LIMIT
151204

152205
# Create header message
153206
header_text = f"Resuming session: {session_id}"
@@ -192,14 +245,15 @@ def display_agent_history(agent, session_id: str) -> None:
192245
print(f"{Fore.WHITE}{content}{Style.RESET_ALL}")
193246
print() # Empty line after assistant message
194247

195-
except Exception:
196-
# If we can't load history, just continue silently
197-
pass
248+
except Exception as e:
249+
# If we can't load history, log the error but continue
250+
logger.warning(f"Failed to display agent history for session {session_id}: {e}")
251+
console.print("[yellow]Warning: Could not load session history[/yellow]")
198252

199253

200254
def setup_session_management(
201255
session_id: Optional[str], session_base_path: Optional[str]
202-
) -> Tuple[Optional[object], Optional[str], bool]:
256+
) -> Tuple[Optional[FileSessionManager], Optional[str], bool]:
203257
"""Set up session management if enabled. Returns (session_manager, session_id, is_resuming)."""
204258
session_manager = None
205259
resolved_session_id = None
@@ -228,32 +282,34 @@ def handle_session_commands(command: str, session_id: Optional[str], session_bas
228282
info = get_session_info(session_id, session_base_path)
229283
if info:
230284
created = datetime.datetime.fromtimestamp(info["created_at"]).strftime("%Y-%m-%d %H:%M:%S")
231-
print(f"Session ID: {info['session_id']}")
232-
print(f"Created: {created}")
233-
print(f"Total messages: {info['total_messages']}")
285+
console.print(f"[bold cyan]Session ID:[/bold cyan] {info['session_id']}")
286+
console.print(f"[bold cyan]Created:[/bold cyan] {created}")
287+
console.print(f"[bold cyan]Total messages:[/bold cyan] {info['total_messages']}")
234288
return True
235289

236290
elif command == "session list" and session_base_path:
237291
sessions = list_available_sessions(session_base_path)
238292
if not sessions:
239-
print("No sessions found.")
293+
console.print("[yellow]No sessions found.[/yellow]")
240294
else:
241-
print("Available sessions:")
295+
console.print("[bold cyan]Available sessions:[/bold cyan]")
242296
for sid in sessions:
243297
info = get_session_info(sid, session_base_path)
244298
if info:
245299
created = datetime.datetime.fromtimestamp(info["created_at"]).strftime("%Y-%m-%d %H:%M:%S")
246-
current = " (current)" if sid == session_id else ""
247-
print(f" {sid} (created: {created}, messages: {info['total_messages']}){current}")
300+
current = " [dim](current)[/dim]" if sid == session_id else ""
301+
console.print(
302+
f" [green]{sid}[/green] (created: {created}, messages: {info['total_messages']}){current}"
303+
)
248304
return True
249305

250306
elif command.startswith("session "):
251307
if session_base_path:
252-
print("Available session commands:")
253-
print(" !session info - Show current session details")
254-
print(" !session list - List all available sessions")
308+
console.print("[bold cyan]Available session commands:[/bold cyan]")
309+
console.print(" [green]!session info[/green] - Show current session details")
310+
console.print(" [green]!session list[/green] - List all available sessions")
255311
else:
256-
print("Error: Session management not enabled.")
312+
console.print("[red]Error: Session management not enabled.[/red]")
257313
return True
258314

259315
return False

tests/test_strands.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,11 @@ def test_eof_error_exception(self, mock_goodbye, mock_agent, mock_input):
172172
# Verify goodbye message was called
173173
mock_goodbye.assert_called_once()
174174

175-
@mock.patch("builtins.print")
175+
@mock.patch("strands_agents_builder.utils.session_utils.console.print")
176176
@mock.patch.object(strands, "get_user_input")
177177
@mock.patch.object(strands, "Agent")
178178
@mock.patch.object(strands, "callback_handler")
179-
def test_general_exception_handling(self, mock_callback_handler, mock_agent, mock_input, mock_print):
179+
def test_general_exception_handling(self, mock_callback_handler, mock_agent, mock_input, mock_console_print):
180180
"""Test handling of general exceptions in interactive mode"""
181181
# Setup mocks
182182
mock_agent_instance = mock.MagicMock()
@@ -194,7 +194,7 @@ def test_general_exception_handling(self, mock_callback_handler, mock_agent, moc
194194
strands.main()
195195

196196
# Verify error was called
197-
mock_print.assert_any_call("Error: Test error")
197+
mock_console_print.assert_any_call("[red]Error: Test error[/red]")
198198

199199
# Verify callback_handler was called to stop spinners
200200
mock_callback_handler.assert_called_once_with(force_stop=True)
@@ -314,9 +314,16 @@ def test_general_exception(self, mock_agent, mock_bedrock, mock_load_prompt, mon
314314
class TestShellCommandError:
315315
"""Test shell command error handling"""
316316

317-
@mock.patch("builtins.print")
317+
@mock.patch("strands_agents_builder.utils.session_utils.console.print")
318318
def test_shell_command_exception(
319-
self, mock_print, mock_agent, mock_bedrock, mock_load_prompt, mock_user_input, mock_welcome_message, monkeypatch
319+
self,
320+
mock_console_print,
321+
mock_agent,
322+
mock_bedrock,
323+
mock_load_prompt,
324+
mock_user_input,
325+
mock_welcome_message,
326+
monkeypatch,
320327
):
321328
"""Test handling exceptions when executing shell commands"""
322329
# Setup mocks
@@ -334,7 +341,7 @@ def test_shell_command_exception(
334341
strands.main()
335342

336343
# Verify error was called
337-
mock_print.assert_any_call("Error: Shell command failed")
344+
mock_console_print.assert_any_call("[red]Error: Shell command failed[/red]")
338345

339346

340347
class TestKnowledgeBaseIntegration:
@@ -573,14 +580,14 @@ def test_agent_creation_without_session_manager(self, mock_create_manager, mock_
573580
call_kwargs = mock_agent_class.call_args[1]
574581
assert "session_manager" not in call_kwargs or call_kwargs.get("session_manager") is None
575582

576-
@mock.patch("builtins.print")
583+
@mock.patch("strands_agents_builder.utils.session_utils.console.print")
577584
@mock.patch("strands_agents_builder.utils.session_utils.create_session_manager")
578585
@mock.patch("strands_agents_builder.utils.session_utils.get_session_info")
579586
def test_session_commands_in_interactive_mode(
580587
self,
581588
mock_get_info,
582589
mock_create_manager,
583-
mock_print,
590+
mock_console_print,
584591
mock_agent,
585592
mock_bedrock,
586593
mock_load_prompt,

0 commit comments

Comments
 (0)