Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions hud/tools/coding/bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,22 @@ class ClaudeBashSession:
"""A persistent bash shell session for Claude's bash tool.

Uses readuntil-based output capture, which is simpler than ShellTool's
polling approach but doesn't support dynamic timeouts.
polling approach.
"""

_started: bool
_process: asyncio.subprocess.Process
_timed_out: bool

command: str = "/bin/bash"
_timeout: float = 120.0 # seconds
_sentinel: str = "<<exit>>"

def __init__(self) -> None:
DEFAULT_TIMEOUT: float = 120.0 # seconds

def __init__(self, timeout: float = DEFAULT_TIMEOUT) -> None:
self._started = False
self._timed_out = False
self._timeout = timeout

async def start(self) -> None:
"""Start the bash session."""
Expand Down Expand Up @@ -77,7 +79,9 @@ async def run(self, command: str) -> ContentResult:
)
if self._timed_out:
raise ToolError(
f"timed out: bash did not return in {self._timeout} seconds and must be restarted",
f"Bash session timed out waiting for output after {self._timeout}s. "
"Background processes may still be running. "
"Use restart=true to get a new session.",
) from None

if self._process.stdin is None:
Expand Down Expand Up @@ -113,7 +117,9 @@ async def run(self, command: str) -> ContentResult:
except (TimeoutError, asyncio.LimitOverrunError):
self._timed_out = True
raise ToolError(
f"timed out: bash did not return in {self._timeout} seconds and must be restarted",
f"Bash session timed out waiting for output after {self._timeout}s. "
"Background processes may still be running. "
"Use restart=true to get a new session.",
) from None

# Attempt non-blocking stderr fetch (may return empty)
Expand Down Expand Up @@ -158,19 +164,27 @@ class BashTool(BaseTool):
),
}

def __init__(self, session: ClaudeBashSession | None = None) -> None:
def __init__(
self,
session: ClaudeBashSession | None = None,
timeout: float = ClaudeBashSession.DEFAULT_TIMEOUT,
) -> None:
"""Initialize BashTool with an optional session.

Args:
session: Optional pre-configured bash session. If not provided,
a new session will be created on first use.
timeout: Timeout in seconds for command execution. Defaults to 120s.
If a pre-configured session is provided, the timeout is
derived from that session instead.
"""
super().__init__(
env=session,
name="bash",
title="Bash Shell",
description="Execute bash commands in a persistent shell session",
)
self._timeout = session._timeout if session is not None else timeout

@property
def session(self) -> ClaudeBashSession | None:
Expand All @@ -195,15 +209,14 @@ async def __call__(
List of MCP ContentBlocks with the result
"""
if restart:
session_cls = type(self.session) if self.session else ClaudeBashSession
if self.session:
self.session.stop()
self.session = session_cls()
self.session = ClaudeBashSession(timeout=self._timeout)
await self.session.start()
return ContentResult(output="Bash session restarted.").to_content_blocks()

if self.session is None:
self.session = ClaudeBashSession()
self.session = ClaudeBashSession(timeout=self._timeout)

if not self.session._started:
await self.session.start()
Expand Down
33 changes: 13 additions & 20 deletions hud/tools/coding/tests/test_bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,40 +211,33 @@ async def test_call_restart_with_existing_session(self):
"""Test restarting the tool when there's an existing session calls stop()."""
tool = BashTool()

# Track calls across instances of our fake session class
stop_called = []
start_called = []

class FakeSession:
"""Fake session that tracks stop/start calls."""

async def start(self) -> None:
start_called.append(True)

def stop(self) -> None:
stop_called.append(True)

# Set up existing session
old_session = FakeSession()
# Set up existing session with a mock
old_session = MagicMock()
old_session.stop = MagicMock()
tool.session = old_session # type: ignore[assignment]

result = await tool(restart=True)
# Mock the new session that will be created
new_session = MagicMock()
new_session.start = AsyncMock()

with patch("hud.tools.coding.bash.ClaudeBashSession", return_value=new_session):
result = await tool(restart=True)

# Verify old session was stopped
assert len(stop_called) == 1, "stop() should be called on old session"
old_session.stop.assert_called_once()

# Verify new session was started
assert len(start_called) == 1, "start() should be called on new session"
new_session.start.assert_called_once()

# Verify result
assert isinstance(result, list)
assert len(result) == 1
assert isinstance(result[0], TextContent)
assert result[0].text == "Bash session restarted."

# Verify new session is a FakeSession (type preserved)
assert isinstance(tool.session, FakeSession)
# Verify new session replaced the old one
assert tool.session is not old_session
assert tool.session is new_session

@pytest.mark.asyncio
async def test_call_no_command_error(self):
Expand Down
30 changes: 28 additions & 2 deletions hud/tools/coding/tests/test_bash_extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,34 @@ async def test_session_run_with_asyncio_timeout(self):
with pytest.raises(ToolError) as exc_info:
await session.run("slow command")

assert "timed out" in str(exc_info.value)
assert "120.0 seconds" in str(exc_info.value)
assert "timed out waiting for output" in str(exc_info.value)
assert "120.0s" in str(exc_info.value)
assert "Background processes may still be running" in str(exc_info.value)
assert "restart=true" in str(exc_info.value)

@pytest.mark.asyncio
async def test_session_run_with_custom_timeout(self):
"""Test that a custom timeout value is used and reported in the error."""
session = _BashSession(timeout=1.0)
assert session._timeout == 1.0

session._started = True

mock_process = MagicMock()
mock_process.returncode = None
mock_process.stdin = MagicMock()
mock_process.stdin.write = MagicMock()
mock_process.stdin.drain = AsyncMock()
mock_process.stdout = MagicMock()
mock_process.stdout.readuntil = AsyncMock(side_effect=TimeoutError())

session._process = mock_process

with pytest.raises(ToolError) as exc_info:
await session.run("sleep 5")

assert "1.0s" in str(exc_info.value)
assert "120" not in str(exc_info.value)

@pytest.mark.asyncio
async def test_session_run_with_stdout_exception(self):
Expand Down
Loading