Skip to content

Commit 3bb0c97

Browse files
committed
Do not subscribe to WS events inside Agent.
1 parent 7f758e3 commit 3bb0c97

File tree

2 files changed

+21
-112
lines changed

2 files changed

+21
-112
lines changed

agents-core/vision_agents/core/agents/agents.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -538,15 +538,6 @@ async def join(
538538
with self.span("edge.publish_tracks"):
539539
await self.edge.publish_tracks(audio_track, video_track)
540540

541-
connection._connection._coordinator_ws_client.on_wildcard(
542-
"*",
543-
lambda event_name, event: self.events.send(event),
544-
)
545-
546-
connection._connection._ws_client.on_wildcard(
547-
"*",
548-
lambda event_name, event: self.events.send(event),
549-
)
550541

551542
from .agent_session import AgentSessionContextManager
552543

plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py

Lines changed: 21 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def __init__(self, connection: ConnectionManager):
4343
def participants(self) -> ParticipantsState:
4444
return self._connection.participants_state
4545

46-
async def close(self):
46+
async def close(self, timeout: float = 2.0):
4747
try:
48-
await asyncio.wait_for(self._connection.leave(), timeout=2.0)
48+
await asyncio.wait_for(self._connection.leave(), timeout=timeout)
4949
except asyncio.TimeoutError:
5050
logger.warning("Connection leave timed out during close")
5151
except RuntimeError as e:
@@ -83,11 +83,19 @@ def __init__(self, **kwargs):
8383
# track_id -> (user_id, session_id, webrtc_type_string)
8484
self._pending_tracks: dict = {}
8585

86+
self._real_connection: Optional[ConnectionManager] = None
87+
8688
# Register event handlers
8789
self.events.subscribe(self._on_track_published)
8890
self.events.subscribe(self._on_track_removed)
8991
self.events.subscribe(self._on_call_ended)
9092

93+
@property
94+
def _connection(self) -> ConnectionManager:
95+
if self._real_connection is None:
96+
raise ValueError("Edge connection is not set")
97+
return self._real_connection
98+
9199
def _get_webrtc_kind(self, track_type_int: int) -> str:
92100
"""Get the expected WebRTC kind (audio/video) for a SFU track type."""
93101
# Map SFU track types to WebRTC kinds
@@ -105,96 +113,6 @@ def _get_webrtc_kind(self, track_type_int: int) -> str:
105113
# Default to video for unknown types
106114
return "video"
107115

108-
async def _subscribe_to_existing_tracks(
109-
self, connection: ConnectionManager
110-
) -> None:
111-
"""Subscribe to tracks from participants who joined before the agent."""
112-
from vision_agents.core.edge.sfu_events import Participant as SfuParticipant
113-
114-
participants = connection.participants_state.get_participants()
115-
subscription_manager = connection._subscription_manager
116-
tracks_to_subscribe = []
117-
118-
for participant in participants:
119-
if participant.user_id == self.agent_user_id:
120-
continue
121-
122-
for track_type_int in participant.published_tracks:
123-
# Create a mock event for the subscription manager
124-
class MockTrackPublishedEvent:
125-
def __init__(self, p, track_type):
126-
self.user_id = p.user_id
127-
self.session_id = p.session_id
128-
self.type = track_type
129-
self.participant = p
130-
131-
mock_event = MockTrackPublishedEvent(participant, track_type_int)
132-
133-
try:
134-
await subscription_manager.handle_track_published(mock_event)
135-
tracks_to_subscribe.append((participant, track_type_int))
136-
except Exception as e:
137-
logger.error(f"Failed to subscribe to existing track: {e}")
138-
139-
# Poll for WebRTC tracks to arrive after subscription
140-
for participant, track_type_int in tracks_to_subscribe:
141-
expected_kind = self._get_webrtc_kind(track_type_int)
142-
track_key = (
143-
participant.user_id,
144-
participant.session_id,
145-
track_type_int,
146-
)
147-
148-
if track_key in self._track_map:
149-
continue
150-
151-
# Poll for WebRTC track ID with timeout (same pattern as _on_track_published)
152-
track_id = None
153-
timeout = 10.0
154-
poll_interval = 0.01
155-
elapsed = 0.0
156-
157-
while elapsed < timeout:
158-
for tid, (pending_user, pending_session, pending_kind) in list(
159-
self._pending_tracks.items()
160-
):
161-
if (
162-
pending_user == participant.user_id
163-
and pending_session == participant.session_id
164-
and pending_kind == expected_kind
165-
):
166-
track_id = tid
167-
del self._pending_tracks[tid]
168-
break
169-
170-
if track_id:
171-
break
172-
173-
await asyncio.sleep(poll_interval)
174-
elapsed += poll_interval
175-
176-
if track_id:
177-
self._track_map[track_key] = {
178-
"track_id": track_id,
179-
"published": True,
180-
}
181-
sfu_participant = SfuParticipant.from_proto(participant)
182-
183-
self.events.send(
184-
events.TrackAddedEvent(
185-
plugin_name="getstream",
186-
track_id=track_id,
187-
track_type=track_type_int,
188-
user=sfu_participant,
189-
participant=sfu_participant,
190-
)
191-
)
192-
else:
193-
logger.warning(
194-
f"No pending track for existing participant: "
195-
f"user={participant.user_id}, type={TrackType.Name(track_type_int)}"
196-
)
197-
198116
async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
199117
"""Handle track published events from SFU - spawn TrackAddedEvent with correct type."""
200118
if not event.payload:
@@ -366,13 +284,10 @@ async def join(self, agent: "Agent", call: Call) -> StreamConnection:
366284
This function
367285
- initializes the chat channel
368286
- has the agent.agent_user join the call
369-
- connect incoming audio/video to the agent
287+
- connects incoming audio/video to the agent
370288
- connecting agent's outgoing audio/video to the call
371-
372-
TODO:
373-
- process track flow
374-
375289
"""
290+
376291
# Traditional mode - use WebRTC connection
377292
# Configure subscription for audio and video
378293
subscription_config = SubscriptionConfig(
@@ -401,13 +316,16 @@ async def on_audio_received(pcm: PcmData):
401316
)
402317
)
403318

404-
await (
405-
connection.__aenter__()
406-
) # TODO: weird API? there should be a manual version
407-
self._connection = connection
319+
# Re-emit certain events from the underlying RTC stack
320+
# for the Agent to subscribe.
321+
connection.on("participant_joined", self.events.send)
322+
connection.on("participant_left", self.events.send)
323+
connection.on("track_published", self.events.send)
324+
connection.on("track_unpublished", self.events.send)
408325

409-
# Subscribe to tracks from participants who joined before the agent
410-
await self._subscribe_to_existing_tracks(connection)
326+
# Start the connection
327+
await connection.__aenter__()
328+
self._real_connection = connection
411329

412330
standardize_connection = StreamConnection(connection)
413331
return standardize_connection

0 commit comments

Comments
 (0)