How can I iterate over an AsyncIterator stream in Python with a timeout, without cancelling the stream?


I'm dealing with an object that is an AsyncIterator[str]. It gets messages from the network, and yields them as strings. I want to create a wrapper for this stream that buffers these messages, and yields them at a regular interval.

My code looks like this:

import asyncio
import time
from typing import AsyncIterator, Optional


async def buffer_stream(stream: AsyncIterator[str], buffer_time: Optional[float]) -> AsyncIterator[str]:
    """
    Buffer messages from the stream, and yield them at regular intervals.
    """
    last_sent_at = time.perf_counter()
    buffer = ''

    stop = False
    while not stop:
        time_to_send = False

        timeout = (
            max(buffer_time - (time.perf_counter() - last_sent_at), 0)
            if buffer_time else None
        )
        try:
            buffer += await asyncio.wait_for(
                stream.__anext__(),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            time_to_send = True
        except StopAsyncIteration:
            time_to_send = True
            stop = True
        else:
            if buffer_time and time.perf_counter() - last_sent_at >= buffer_time:
                time_to_send = True

        if not buffer_time or time_to_send:
            if buffer:
                yield buffer
                buffer = ''
            last_sent_at = time.perf_counter()

As far as I can tell, the logic makes sense, but as soon as the first timeout occurs, the stream is interrupted and buffer_stream exits early, before the stream is done.

I think this might be because asyncio.wait_for specifically says:

When a timeout occurs, it cancels the task and raises TimeoutError. To avoid the task cancellation, wrap it in shield().
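That explanation can be checked with a minimal sketch, assuming the stream is an async generator (the "asynchronous generator is already running" error quoted in the question suggests it is). slow_stream() and cancel_demo() are hypothetical names standing in for the real network stream. When wait_for() times out, it cancels the pending __anext__(); the resulting CancelledError is raised inside the generator at its current await and ends it, so the next __anext__() behaves as if the stream were exhausted:

```python
import asyncio


async def slow_stream():
    # Hypothetical stand-in for the network stream: one message every 0.2 s.
    for i in range(3):
        await asyncio.sleep(0.2)
        yield f"msg{i}"


async def cancel_demo() -> list:
    events = []
    stream = slow_stream()
    try:
        # Times out long before the first message is ready.
        await asyncio.wait_for(stream.__anext__(), timeout=0.05)
    except asyncio.TimeoutError:
        events.append("timed out")
    try:
        # The cancellation on timeout was raised inside the generator and
        # ended it, so it now behaves like an exhausted stream.
        await stream.__anext__()
    except StopAsyncIteration:
        events.append("stream exhausted")
    return events


print(asyncio.run(cancel_demo()))  # prints ['timed out', 'stream exhausted']
```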

I tried wrapping it in asyncio.shield():

buffer += await asyncio.wait_for(
    asyncio.shield(stream.__anext__()),
    timeout=timeout
)

This fails for a different reason: RuntimeError: anext(): asynchronous generator is already running. From what I understand, the shielded __anext__() from the previous iteration is still running when the loop calls __anext__() again, and an async generator can't have two __anext__() calls in progress at once.
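A minimal sketch of that situation, under the same assumption that the stream is an async generator (slow_stream() and shield_demo() are hypothetical names): with shield(), the timeout leaves the first __anext__() running, so starting a second one is rejected, while awaiting the original task again still delivers the first message.

```python
import asyncio


async def slow_stream():
    # Hypothetical stand-in for the network stream: one message every 0.2 s.
    for i in range(3):
        await asyncio.sleep(0.2)
        yield f"msg{i}"


async def shield_demo():
    stream = slow_stream()
    first = asyncio.ensure_future(stream.__anext__())
    try:
        # The timeout cancels only the shield() wrapper; `first` keeps
        # running inside the generator.
        await asyncio.wait_for(asyncio.shield(first), timeout=0.05)
    except asyncio.TimeoutError:
        pass
    error = None
    try:
        # A second __anext__() while `first` is still pending, which is what
        # the loop in the question does on its next iteration.
        await stream.__anext__()
    except RuntimeError as exc:
        error = str(exc)
    # Re-awaiting the original task, instead of starting a new read, works.
    value = await first
    await stream.aclose()
    return error, value


print(asyncio.run(shield_demo()))
```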

Is there a proper way to do this?

Demo: https://www.sololearn.com/en/compiler-playground/cBCVnVAD4H7g



Best answer, by user4815162342:

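You can turn the result of stream.__anext__() into a task (or, more generally, a future) once, and keep awaiting that same future through asyncio.shield() until it yields a result. A timeout then cancels only the shield wrapper, not the pending __anext__(), and no second __anext__() is started while the first one is still running: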

import asyncio
import time
from typing import AsyncIterator, Optional


async def buffer_stream(stream: AsyncIterator[str], buffer_time: Optional[float]) -> AsyncIterator[str]:
    last_sent_at = time.perf_counter()
    buffer = ''

    stop = False
    await_next = None
    while not stop:
        time_to_send = False

        timeout = (
            max(buffer_time - (time.perf_counter() - last_sent_at), 0)
            if buffer_time else None
        )
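        if await_next is None:
            # Start a new __anext__() only after the previous one has finished;
            # the same future is re-awaited across timeouts.
            await_next = asyncio.ensure_future(stream.__anext__())
        try:
            buffer += await asyncio.wait_for(
                # shield() makes the timeout cancel only this wrapper, so the
                # pending __anext__() keeps running in the background.
                asyncio.shield(await_next),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            time_to_send = True
        except StopAsyncIteration:
            time_to_send = True
            stop = True
        else:
            await_next = None
            if buffer_time and time.perf_counter() - last_sent_at >= buffer_time:
                time_to_send = True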

        if not buffer_time or time_to_send:
            if buffer:
                yield buffer
                buffer = ''
            last_sent_at = time.perf_counter()
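An alternative sketch, not the answerer's code: asyncio.wait() returns on timeout without cancelling the futures it waits on, so it can replace the wait_for()/shield() pair. The names buffer_stream_wait(), _next_or_sentinel(), _DONE and fake_stream() are hypothetical. The helper turns StopAsyncIteration into a sentinel value so the end of the stream arrives as an ordinary result, and the finally clause cancels a read that is still in flight if the consumer stops iterating early.

```python
import asyncio
import time
from typing import AsyncIterator, Optional

_DONE = object()  # hypothetical sentinel marking the end of the stream


async def _next_or_sentinel(stream: AsyncIterator[str]):
    # Return _DONE instead of raising StopAsyncIteration.
    try:
        return await stream.__anext__()
    except StopAsyncIteration:
        return _DONE


async def buffer_stream_wait(
    stream: AsyncIterator[str], buffer_time: Optional[float]
) -> AsyncIterator[str]:
    buffer = ""
    last_sent_at = time.perf_counter()
    pending: Optional[asyncio.Task] = None
    try:
        while True:
            # Start a new read only after the previous one has completed.
            if pending is None:
                pending = asyncio.ensure_future(_next_or_sentinel(stream))
            timeout = (
                max(buffer_time - (time.perf_counter() - last_sent_at), 0)
                if buffer_time else None
            )
            # Unlike wait_for(), asyncio.wait() cancels nothing on timeout;
            # an unfinished task is simply reported as pending.
            done, _ = await asyncio.wait({pending}, timeout=timeout)
            finished = False
            if done:
                item = pending.result()
                pending = None
                if item is _DONE:
                    finished = True
                else:
                    buffer += item
            elapsed = time.perf_counter() - last_sent_at
            if finished or not buffer_time or elapsed >= buffer_time:
                if buffer:
                    yield buffer
                    buffer = ""
                last_sent_at = time.perf_counter()
            if finished:
                return
    finally:
        # If the consumer stops early, don't leave a read running.
        if pending is not None and not pending.done():
            pending.cancel()


async def fake_stream():
    # Hypothetical stream: "a" and "b" quickly, then "c" after a pause.
    for delay, msg in [(0.01, "a"), (0.01, "b"), (0.3, "c")]:
        await asyncio.sleep(delay)
        yield msg


async def main():
    return [chunk async for chunk in buffer_stream_wait(fake_stream(), 0.1)]


if __name__ == "__main__":
    print(asyncio.run(main()))
```

With these delays, "a" and "b" should arrive within the first 0.1 s window and come out together, and "c" should come out on its own when the stream ends. As in the answer, each pending read is created once and reused until it completes, so the generator never sees two overlapping __anext__() calls.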