@@ -1,18 +1,18 @@
 import asyncio
 import logging
+from asyncio import TimeoutError
+from asyncio.queues import QueueEmpty
 from datetime import datetime, timezone
-from typing import Optional, Any, Dict, Iterator, AsyncIterator
+from typing import Any, AsyncIterator, Dict, Optional

+from aiobotocore.session import AioSession
 from aiohttp import ClientConnectionError
-from asyncio import TimeoutError
-from asyncio.queues import QueueEmpty
 from botocore.exceptions import ClientError
-from aiobotocore.session import AioSession

-from .utils import Throttler
 from .base import Base
-from .checkpointers import MemoryCheckPointer, CheckPointer
+from .checkpointers import CheckPointer, MemoryCheckPointer
 from .processors import JsonProcessor, Processor
+from .utils import Throttler

 log = logging.getLogger(__name__)

@@ -106,7 +106,8 @@ async def close(self):

         if self.checkpointer:
             await self.checkpointer.close()
-        await self.client.close()
+        if self.client is not None:
+            await self.client.close()

     async def flush(self):

@@ -123,15 +124,29 @@ async def flush(self):
                 await shard["fetch"]

     async def _fetch(self):
+        error_count = 0
+        max_errors = 10
+
         while self.is_fetching:
             # Ensure fetch is performed at most 5 times per second (the limit per shard)
             await asyncio.sleep(0.2)
             try:
                 await self.fetch()
+                error_count = 0  # Reset error count on successful fetch
             except asyncio.CancelledError:
-                pass
+                log.debug("Fetch task cancelled")
+                self.is_fetching = False
+                break
             except Exception as e:
                 log.exception(e)
+                error_count += 1
+                if error_count >= max_errors:
+                    log.error(
+                        f"Too many fetch errors ({max_errors}), stopping fetch task"
+                    )
+                    self.is_fetching = False
+                    break
+                await asyncio.sleep(min(2 ** error_count, 30))  # Exponential backoff

     async def fetch(self):

@@ -218,7 +233,13 @@ async def fetch(self):
                             for n, output in enumerate(
                                 self.processor.parse(row["Data"])
                             ):
-                                await self.queue.put(output)
+                                try:
+                                    await asyncio.wait_for(
+                                        self.queue.put(output), timeout=30.0
+                                    )
+                                except asyncio.TimeoutError:
+                                    log.warning("Queue put timed out, skipping record")
+                                    continue
                                 total_items += n + 1

                         # Get approx minutes behind..
@@ -253,14 +274,23 @@ async def fetch(self):

                         # Add checkpoint record
                         last_record = result["Records"][-1]
-                        await self.queue.put(
-                            {
-                                "__CHECKPOINT__": {
-                                    "ShardId": shard["ShardId"],
-                                    "SequenceNumber": last_record["SequenceNumber"],
-                                }
-                            }
-                        )
+                        try:
+                            await asyncio.wait_for(
+                                self.queue.put(
+                                    {
+                                        "__CHECKPOINT__": {
+                                            "ShardId": shard["ShardId"],
+                                            "SequenceNumber": last_record[
+                                                "SequenceNumber"
+                                            ],
+                                        }
+                                    }
+                                ),
+                                timeout=30.0,
+                            )
+                        except asyncio.TimeoutError:
+                            log.warning("Checkpoint queue put timed out")
+                            # Continue without checkpoint - not critical

                         shard["LastSequenceNumber"] = last_record["SequenceNumber"]

@@ -302,7 +332,7 @@ async def get_records(self, shard):
                 shard["stats"].succeded()
                 return result

-            except ClientConnectionError as e:
+            except ClientConnectionError:
                 await self.get_conn()
             except TimeoutError as e:
                 log.warning("Timeout {}. sleeping..".format(e))
@@ -358,17 +388,17 @@ async def get_shard_iterator(self, shard_id, last_sequence_number=None):

         params = {
             "ShardId": shard_id,
-            "ShardIteratorType": "AFTER_SEQUENCE_NUMBER"
-            if last_sequence_number
-            else self.iterator_type,
+            "ShardIteratorType": (
+                "AFTER_SEQUENCE_NUMBER" if last_sequence_number else self.iterator_type
+            ),
         }
         params.update(self.address)

         if last_sequence_number:
             params["StartingSequenceNumber"] = last_sequence_number

-        if self.iterator_type == 'AT_TIMESTAMP' and self.timestamp:
-            params['Timestamp'] = self.timestamp
+        if self.iterator_type == "AT_TIMESTAMP" and self.timestamp:
+            params["Timestamp"] = self.timestamp

         response = await self.client.get_shard_iterator(**params)
         return response["ShardIterator"]
@@ -397,7 +427,12 @@ async def __anext__(self):
         # Raise exception from Fetch Task to main task otherwise raise exception inside
         # Fetch Task will fail silently
         if self.fetch_task.done():
-            raise self.fetch_task.exception()
+            exception = self.fetch_task.exception()
+            if exception:
+                raise exception
+
+        checkpoint_count = 0
+        max_checkpoints = 100  # Prevent infinite checkpoint processing

         while True:
             try:
@@ -409,6 +444,12 @@ async def __anext__(self):
409444 item ["__CHECKPOINT__" ]["ShardId" ],
410445 item ["__CHECKPOINT__" ]["SequenceNumber" ],
411446 )
447+ checkpoint_count += 1
448+ if checkpoint_count >= max_checkpoints :
449+ log .warning (
450+ f"Processed { max_checkpoints } checkpoints, stopping iteration"
451+ )
452+ raise StopAsyncIteration
412453 continue
413454
414455 return item
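The two safeguards added above in `_fetch` and the record loop (capped retries with exponential backoff, plus `asyncio.wait_for` around `queue.put` so a full queue cannot block the fetch task forever) can be exercised in isolation. Below is a minimal standalone sketch of the same pattern; `fake_fetch`, `producer`, and `main` are hypothetical names and not part of this codebase, though the 30-second timeout and backoff cap mirror the values used in the diff.

```python
import asyncio
import logging
import random

log = logging.getLogger(__name__)


async def fake_fetch():
    # Stand-in for a GetRecords call; fails randomly to exercise the backoff path.
    if random.random() < 0.3:
        raise ConnectionError("simulated fetch failure")
    return {"Records": [{"Data": b"x"}]}


async def producer(queue, max_errors=10):
    error_count = 0
    while True:
        await asyncio.sleep(0.2)  # at most ~5 fetches per second
        try:
            result = await fake_fetch()
            error_count = 0  # reset on success
        except Exception as e:
            log.exception(e)
            error_count += 1
            if error_count >= max_errors:
                log.error("Too many fetch errors (%s), stopping", max_errors)
                break
            await asyncio.sleep(min(2 ** error_count, 30))  # exponential backoff
            continue
        for record in result["Records"]:
            try:
                # Bounded put: a full queue can no longer wedge the fetch task forever.
                await asyncio.wait_for(queue.put(record), timeout=30.0)
            except asyncio.TimeoutError:
                log.warning("Queue put timed out, skipping record")


async def main():
    queue = asyncio.Queue(maxsize=100)
    task = asyncio.create_task(producer(queue))
    for _ in range(5):
        print(await queue.get())
    task.cancel()


if __name__ == "__main__":
    asyncio.run(main())
```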
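The `__anext__` change also tightens how a failed fetch task surfaces: `Task.exception()` returns `None` when the task finished without error, and `raise None` would itself fail with a `TypeError`, so the new code only re-raises when an exception is actually present. A compact sketch of that pattern, using hypothetical `Stream`/`_fetch` names rather than this library's classes:

```python
import asyncio
from typing import Optional


class Stream:
    """Minimal async iterator mirroring the fetch-task error handling above."""

    def __init__(self):
        self.queue: asyncio.Queue = asyncio.Queue()
        self.fetch_task: Optional[asyncio.Task] = None

    async def _fetch(self):
        # Simulated background failure; the real consumer's loop calls fetch() here.
        await asyncio.sleep(0.05)
        raise RuntimeError("simulated fetch failure")

    def __aiter__(self):
        self.fetch_task = asyncio.ensure_future(self._fetch())
        return self

    async def __anext__(self):
        while True:
            if self.fetch_task and self.fetch_task.done():
                # Re-raise the background failure in the consumer instead of letting
                # it die silently; exception() is None when the task exited cleanly.
                exception = self.fetch_task.exception()
                if exception:
                    raise exception
                raise StopAsyncIteration
            try:
                return self.queue.get_nowait()
            except asyncio.QueueEmpty:
                await asyncio.sleep(0.1)


async def main():
    try:
        async for item in Stream():
            print(item)
    except RuntimeError as e:
        print("fetch task failed:", e)


if __name__ == "__main__":
    asyncio.run(main())
```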