@@ -43,9 +43,9 @@ def __init__(self, connection: ConnectionManager):
4343 def participants (self ) -> ParticipantsState :
4444 return self ._connection .participants_state
4545
46- async def close (self ):
46+ async def close (self , timeout : float = 2.0 ):
4747 try :
48- await asyncio .wait_for (self ._connection .leave (), timeout = 2.0 )
48+ await asyncio .wait_for (self ._connection .leave (), timeout = timeout )
4949 except asyncio .TimeoutError :
5050 logger .warning ("Connection leave timed out during close" )
5151 except RuntimeError as e :
@@ -83,11 +83,19 @@ def __init__(self, **kwargs):
8383 # track_id -> (user_id, session_id, webrtc_type_string)
8484 self ._pending_tracks : dict = {}
8585
86+ self ._real_connection : Optional [ConnectionManager ] = None
87+
8688 # Register event handlers
8789 self .events .subscribe (self ._on_track_published )
8890 self .events .subscribe (self ._on_track_removed )
8991 self .events .subscribe (self ._on_call_ended )
9092
93+ @property
94+ def _connection (self ) -> ConnectionManager :
95+ if self ._real_connection is None :
96+ raise ValueError ("Edge connection is not set" )
97+ return self ._real_connection
98+
9199 def _get_webrtc_kind (self , track_type_int : int ) -> str :
92100 """Get the expected WebRTC kind (audio/video) for a SFU track type."""
93101 # Map SFU track types to WebRTC kinds
@@ -105,96 +113,6 @@ def _get_webrtc_kind(self, track_type_int: int) -> str:
105113 # Default to video for unknown types
106114 return "video"
107115
108- async def _subscribe_to_existing_tracks (
109- self , connection : ConnectionManager
110- ) -> None :
111- """Subscribe to tracks from participants who joined before the agent."""
112- from vision_agents .core .edge .sfu_events import Participant as SfuParticipant
113-
114- participants = connection .participants_state .get_participants ()
115- subscription_manager = connection ._subscription_manager
116- tracks_to_subscribe = []
117-
118- for participant in participants :
119- if participant .user_id == self .agent_user_id :
120- continue
121-
122- for track_type_int in participant .published_tracks :
123- # Create a mock event for the subscription manager
124- class MockTrackPublishedEvent :
125- def __init__ (self , p , track_type ):
126- self .user_id = p .user_id
127- self .session_id = p .session_id
128- self .type = track_type
129- self .participant = p
130-
131- mock_event = MockTrackPublishedEvent (participant , track_type_int )
132-
133- try :
134- await subscription_manager .handle_track_published (mock_event )
135- tracks_to_subscribe .append ((participant , track_type_int ))
136- except Exception as e :
137- logger .error (f"Failed to subscribe to existing track: { e } " )
138-
139- # Poll for WebRTC tracks to arrive after subscription
140- for participant , track_type_int in tracks_to_subscribe :
141- expected_kind = self ._get_webrtc_kind (track_type_int )
142- track_key = (
143- participant .user_id ,
144- participant .session_id ,
145- track_type_int ,
146- )
147-
148- if track_key in self ._track_map :
149- continue
150-
151- # Poll for WebRTC track ID with timeout (same pattern as _on_track_published)
152- track_id = None
153- timeout = 10.0
154- poll_interval = 0.01
155- elapsed = 0.0
156-
157- while elapsed < timeout :
158- for tid , (pending_user , pending_session , pending_kind ) in list (
159- self ._pending_tracks .items ()
160- ):
161- if (
162- pending_user == participant .user_id
163- and pending_session == participant .session_id
164- and pending_kind == expected_kind
165- ):
166- track_id = tid
167- del self ._pending_tracks [tid ]
168- break
169-
170- if track_id :
171- break
172-
173- await asyncio .sleep (poll_interval )
174- elapsed += poll_interval
175-
176- if track_id :
177- self ._track_map [track_key ] = {
178- "track_id" : track_id ,
179- "published" : True ,
180- }
181- sfu_participant = SfuParticipant .from_proto (participant )
182-
183- self .events .send (
184- events .TrackAddedEvent (
185- plugin_name = "getstream" ,
186- track_id = track_id ,
187- track_type = track_type_int ,
188- user = sfu_participant ,
189- participant = sfu_participant ,
190- )
191- )
192- else :
193- logger .warning (
194- f"No pending track for existing participant: "
195- f"user={ participant .user_id } , type={ TrackType .Name (track_type_int )} "
196- )
197-
198116 async def _on_track_published (self , event : sfu_events .TrackPublishedEvent ):
199117 """Handle track published events from SFU - spawn TrackAddedEvent with correct type."""
200118 if not event .payload :
@@ -366,13 +284,10 @@ async def join(self, agent: "Agent", call: Call) -> StreamConnection:
366284 This function
367285 - initializes the chat channel
368286 - has the agent.agent_user join the call
369- - connect incoming audio/video to the agent
287+ - connects incoming audio/video to the agent
370288 - connecting agent's outgoing audio/video to the call
371-
372- TODO:
373- - process track flow
374-
375289 """
290+
376291 # Traditional mode - use WebRTC connection
377292 # Configure subscription for audio and video
378293 subscription_config = SubscriptionConfig (
@@ -401,13 +316,16 @@ async def on_audio_received(pcm: PcmData):
401316 )
402317 )
403318
404- await (
405- connection .__aenter__ ()
406- ) # TODO: weird API? there should be a manual version
407- self ._connection = connection
319+ # Re-emit certain events from the underlying RTC stack
320+ # for the Agent to subscribe.
321+ connection .on ("participant_joined" , self .events .send )
322+ connection .on ("participant_left" , self .events .send )
323+ connection .on ("track_published" , self .events .send )
324+ connection .on ("track_unpublished" , self .events .send )
408325
409- # Subscribe to tracks from participants who joined before the agent
410- await self ._subscribe_to_existing_tracks (connection )
326+ # Start the connection
327+ await connection .__aenter__ ()
328+ self ._real_connection = connection
411329
412330 standardize_connection = StreamConnection (connection )
413331 return standardize_connection
0 commit comments