How can I convert a Python multithreaded socket client/server to asyncio?


I'm writing a tool to test an asyncio-based server end-to-end. Initially I was going to spin up the server in one terminal window and run the test in a separate window, but then I realized it should be possible to do both in one script. After all, I can do it with a concurrent.futures.ThreadPoolExecutor, but I'm struggling to convert the logic to await/async def.

Here's a working example using the TPE:

import argparse
import socket
import concurrent.futures
import threading
import socketserver


class TCPHandler(socketserver.BaseRequestHandler):
    def handle(self):
        print(f'Got data: {self.request.recv(1024).strip().decode()}')

def started_server(*, server):
    print('starting server')
    server.serve_forever()
    print('server thread closing')


def run_client(*, host, port, server):
    print('client started, attempting connection')
    with socket.create_connection((host, port), timeout=0.5) as conn:
        print('connected')
        conn.send(b'hello werld')
    print('closing server')
    server.shutdown()
    print('cancelled')

def test_the_server(*, host, port):
    ex = concurrent.futures.ThreadPoolExecutor(max_workers=3)
    print('server a')
    quitter = threading.Event()
    server = socketserver.TCPServer((host, port), TCPHandler)
    a = ex.submit(started_server, server=server)
    b = ex.submit(run_client, host=host, port=port, server=server)
    print(a.result(), b.result())
    print('server b')


def do_it():  # Shia LeBeouf!
    parser = argparse.ArgumentParser(usage=__doc__)
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("-p", "--port", type=int, default=60025)
    args = parser.parse_args()

    exit(test_the_server(host=args.host, port=args.port))


if __name__ == "__main__":
    do_it()

How would I convert this to use an asyncio loop? I'm pretty sure I need to spawn the server's asyncio loop in a thread, but so far my attempts just end up blocking, and other questions on SO have failed to provide a solution (or are outdated).

Here's an example of something that fails for me:

import asyncio
import argparse
import socket
import concurrent.futures
import threading
import socketserver


class EchoHandler(asyncio.Protocol):
    def data_received(self, data):
        print(f"Got this data: {data.decode()}")

async def run_server(*, server):
    print('starting server')
    server = await server
    async with server:
        print('start serving')
        await server.start_serving()
        print('waiting on close')
        await server.wait_closed()
    print('server coro closing')

def started_server(*, server):
    print('server thread started')
    asyncio.run(run_server(server=server))
    print('server thread finished')

def run_client(*, host, port, server):
    print('client started, attempting connection')
    with socket.create_connection((host, port), timeout=0.5) as conn:
        print('connected')
        conn.send(b'hello werld')
    print('closing server')
    server.close()
    print('cancelled')

async def fnord(reader, writer):
    data = await reader.read(100)
    message = data.decode()
    print('got', message)

def test_the_server(*, host, port):
    ex = concurrent.futures.ThreadPoolExecutor(max_workers=3)
    print('server a')
    quitter = threading.Event()
    #server = socketserver.TCPServer((host, port), TCPHandler)
    server = asyncio.start_server(fnord, host, port)
    a = ex.submit(started_server, server=server)
    b = ex.submit(run_client, host=host, port=port, server=server)
    print(a.result(), b.result())
    print('server b')


def do_it():  # Shia LeBeouf!
    parser = argparse.ArgumentParser(usage=__doc__)
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("-p", "--port", type=int, default=60025)
    args = parser.parse_args()

    exit(test_the_server(host=args.host, port=args.port))


if __name__ == "__main__":
    do_it()

I was hoping that https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.Server.wait_closed would mean that calling server.close() from the other thread would shut down the server, but that doesn't appear to be the case. serve_forever behaves the same as the start_serving approach.

There are 4 answers below.

Answer from bmitc

As I mentioned in my comments, I'm not exactly sure what you're looking for, but I thought I'd put together an example in which you run both a TCP server and a client in the same asyncio event loop. I adapted the examples from the asyncio streams documentation.

import asyncio


async def test_client(message):
    reader, writer = await asyncio.open_connection("127.0.0.1", 9000)

    writer.write(f"{message}\n".encode())
    await writer.drain()
    print(f"Test client sent the message: {message}")

    data = await reader.readline()
    response = data.decode().strip()
    print(f"Test client received the response: {response}")

    writer.close()
    await writer.wait_closed()


async def handle_echo(reader, writer):
    data = await reader.readline()
    message = data.decode().strip()
    print(f"Server received message: {message}")

    writer.write(f"{message}\n".encode())
    await writer.drain()

    writer.close()
    await writer.wait_closed()


async def server():
    server = await asyncio.start_server(handle_echo, "127.0.0.1", 9000)

    async with server:
        await server.serve_forever()


async def main():
    await asyncio.gather(
        server(),
        test_client("This is only a test"),
    )


asyncio.run(main())

This example doesn't make for a great script, though, because the server will run forever until Ctrl+C. It prints:

Test client sent the message: This is only a test
Server received message: This is only a test
Test client received the response: This is only a test

Here's an example using tasks directly (note that asyncio.gather turns the coroutines it is passed into tasks) which allows you to cancel a task if you have a reference to it. Everything is the same except for the change to main and the extra coroutine.

async def cancel_task_after(task: asyncio.Task, seconds_to_wait: int) -> None:
    await asyncio.sleep(seconds_to_wait)
    task.cancel()
    return None


async def main():
    async with asyncio.TaskGroup() as task_group:
        server_task = task_group.create_task(server())
        client_task = task_group.create_task(test_client("This is only a test"))
        cancel_server_task = task_group.create_task(cancel_task_after(server_task, 3))


asyncio.run(main())

Answer from Andrej Kesely

Here is an example of how you can run the server and client in one script (and on one asyncio loop):

import asyncio


class EchoServerProtocol(asyncio.Protocol):
    def connection_made(self, transport):
        peername = transport.get_extra_info("peername")
        print("[Server] Connection from {}".format(peername))
        self.transport = transport

    def data_received(self, data):
        message = data.decode()
        print("[Server] Data received: {!r}".format(message))

        print("[Server] Send: {!r}".format(message))
        self.transport.write(data)

        print("[Server] Close the client socket")
        self.transport.close()


class EchoClientProtocol(asyncio.Protocol):
    def __init__(self, message, on_con_lost):
        self.message = message
        self.on_con_lost = on_con_lost

    def connection_made(self, transport):
        transport.write(self.message.encode())
        print("[Client] Data sent: {!r}".format(self.message))

    def data_received(self, data):
        print("[Client] Data received: {!r}".format(data.decode()))

    def connection_lost(self, exc):
        print("[Client] The server closed the connection")
        self.on_con_lost.set_result(True)


async def main():
    loop = asyncio.get_running_loop()

    server = await loop.create_server(lambda: EchoServerProtocol(), "127.0.0.1", 8888)

    async with server:
        on_con_lost = loop.create_future()

        transport, protocol = await loop.create_connection(
            lambda: EchoClientProtocol("Hello World", on_con_lost), "127.0.0.1", 8888
        )

        await on_con_lost


if __name__ == "__main__":
    asyncio.run(main())

When you run this script it will print:

[Server] Connection from ('127.0.0.1', 41438)
[Client] Data sent: 'Hello World'
[Server] Data received: 'Hello World'
[Server] Send: 'Hello World'
[Server] Close the client socket
[Client] Data received: 'Hello World'
[Client] The server closed the connection

Answer from Booboo

Update

The issue with your code, if I have correctly understood your post, is that you are executing server = asyncio.start_server(fnord, host, port), but the returned value is a coroutine object that has to be awaited before an actual asyncio.Server exists. The un-awaited coroutine is not something you can call close on.
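
Roughly, the distinction looks like this (fnord, host and port as in your code; this is just an illustration):

coro = asyncio.start_server(fnord, host, port)  # a coroutine object; it has no close()
server = await coro                             # an asyncio.Server; close()/wait_closed() live here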

If you want to terminate the server cleanly by calling close, then you should use the loop method create_server, as in:

server = await asyncio.get_running_loop().create_server(...)

The call to create_server, however, requires a protocol factory (for example, an asyncio.Protocol subclass), not a (reader, writer) callback.

The following code leaves the client unmodified, and since it is not based on asyncio coroutines, it is run in a thread pool (a multiprocessing pool would be preferable if the client were CPU-intensive). If the command line does not contain the -t or --test flag, the server runs as a normal server awaiting client connections until the Enter key is hit:

import asyncio
import aioconsole
import concurrent.futures
import socket


class EchoHandler(asyncio.Protocol):
    def connection_made(self, transport):
        self._transport = transport

    def data_received(self, data):
        print('Server received:', data.decode())
        self._transport.write(data)


async def start_server(server, server_started):
    print('starting server')
    async with server:
        print('start serving')
        await server.start_serving()
        # Notify we have started
        server_started.set()
        print('waiting on close')
        await server.wait_closed()
    print('server coro closing')


async def run_server(host, port):
    server_started = asyncio.Event()
    loop = asyncio.get_running_loop()
    server = await loop.create_server(EchoHandler, host, port)
    server_task = loop.create_task(start_server(server, server_started))
    # Wait for the server to start
    await server_started.wait()
    return server_task, server


def run_client(host, port):
    print('client started, attempting connection')
    with socket.create_connection((host, port), timeout=0.5) as conn:
        print('connected')
        conn.send(b'Hello world!')
        reply = conn.recv(100).decode()
        print('Client reply was:', reply)


async def serve(host, port, test=False):
    loop = asyncio.get_running_loop()
    server_task, server = await run_server(host, port)

    if test:
        # Run the client, await its completion
        # and then shutdown:
        with concurrent.futures.ThreadPoolExecutor(1) as executor:
            await loop.run_in_executor(executor, run_client, host, port)
    else:
        await aioconsole.ainput('Hit Enter to terminate ...')

    # Now shut down the server and await its completion:
    server.close()
    await server_task


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(usage=__doc__)
    parser.add_argument("-t", "--test", action='store_const', const=True, default=False)
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("-p", "--port", type=int, default=60025)
    args = parser.parse_args()

    asyncio.run(serve(host=args.host, port=args.port, test=args.test))

Prints:

starting server
start serving
waiting on close
client started, attempting connection
connected
Server received: Hello world!
Client reply was: Hello world!
server coro closing

Answer from Wayne Werner

The problem here is that async def code is expected to be called and executed from within other async def code, aside from a few exceptions, such as asyncio.run.

Generally speaking, asyncio manages its own concurrency, and its objects aren't meant to be driven from other threads:

Almost all asyncio objects are not thread safe, which is typically not a problem unless there is code that works with them from outside of a Task or a callback.
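
For reference, the thread-safe entry points that quote is about are loop.call_soon_threadsafe and asyncio.run_coroutine_threadsafe. Here's a minimal sketch (not what the rest of this answer ends up doing) of running the server's loop in a background thread and closing it from the plain-socket side; handle_echo, HOST and PORT are placeholder names:

import asyncio
import socket
import threading

HOST, PORT = "127.0.0.1", 60026

async def handle_echo(reader, writer):
    data = await reader.read(100)
    print("Server got:", data.decode())
    writer.close()
    await writer.wait_closed()

def main():
    loop = asyncio.new_event_loop()
    t = threading.Thread(target=loop.run_forever, daemon=True)
    t.start()

    # Create the server on the loop's thread; block here until it is listening.
    server = asyncio.run_coroutine_threadsafe(
        asyncio.start_server(handle_echo, HOST, PORT), loop
    ).result()

    # Plain blocking-socket client, as in the question.
    with socket.create_connection((HOST, PORT), timeout=0.5) as conn:
        conn.send(b"hello werld")

    # Server.close() is not thread safe, so schedule it on the loop's own thread.
    loop.call_soon_threadsafe(server.close)
    asyncio.run_coroutine_threadsafe(server.wait_closed(), loop).result()
    loop.call_soon_threadsafe(loop.stop)
    t.join()
    loop.close()

if __name__ == "__main__":
    main()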

In my experiments in early-to-mid 2024, Python does not offer a simple way to extract socket-based objects in a working state from an asyncio loop. However, it is possible using low-level primitives such as asyncio.new_event_loop. What we're doing here is nothing more than what asyncio normally manages for us, that is, cooperative multitasking -- here our workload is IO-bound (i.e. we're not waiting on the CPU, we're waiting on sockets).

I was attempting to use threads, a la:

sync/normal python -,---> asyncio  <------, (server.close)
                     `--> threaded code --'

Which didn't work with async def/async with/await code. It seems like it should, since I can return results from async def functions like this:

import asyncio

async def the_answer():
    return 42

def the_question():
    result = asyncio.run(the_answer())
    print("The answer to life, the universe, and everything is", result)

if __name__ == "__main__":
    the_question()

If I can return 42, it would make sense to think that one could generate a server, and start the server listening this way. And indeed, you can, sort-of!

import asyncio
import time


class EchoHandler(asyncio.Protocol):
    def connection_made(self, transport):
        self._transport = transport

    def data_received(self, data):
        print('Server received:', data.decode())
        self._transport.write(data)

async def the_server(loop, host, port):
    server = await loop.create_server(EchoHandler, host, port)
    await asyncio.sleep(10)
    return server

def the_question():
    host = '127.0.0.1'
    port = 60025
    loop = asyncio.new_event_loop()
    server = loop.run_until_complete(the_server(loop, host, port))
    print("This is my server", server)
    time.sleep(5)

While the_server is awaiting asyncio.sleep, it will print any data that's received. However, once run_until_complete finishes, the server will no longer respond to connection attempts (they fail with ConnectionRefusedError when made via the multithreaded approach).

However, I noticed that the server was printed with an existing socket:

This is my server <Server sockets=(<asyncio.TransportSocket fd=6, family=2, type=1, proto=6, laddr=('127.0.0.1', 60025)>,)>

And connections to my server via netcat weren't failing, they were just apparently hanging open, and nothing was being printed from the echo server. What do I know about async python?

Well:

  1. I know that async def and await are just syntactic sugar.
    >>> async def foo():
    ...  return 42
    ... 
    >>> x = foo()
    >>> x
    <coroutine object foo at 0x7f63a9d6ca90>
    >>> x.send(None)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration: 42
    
  2. I know that asyncio is only one of many async libraries.
  3. I know that it runs its own event loop to manage processing.
  4. I know that it's cooperative multitasking.

I tried creating a simple connection from within this same thread:

    print("This is my server", server)                                          
    print(loop.is_closed())                                                                                                                            
    with socket.create_connection((host, port), timeout=0.5) as conn:           
        print('connected?')                                                     
        conn.send(b'hello werld')                                                      
        print('done')                                                           
    time.sleep(5)                                                               
    print('done')  

And the connection succeeded: no timeout error, no connection refused. But no message was printed either. So now I know my server is still running. Except it isn't really, right? In fact, this failed:

        conn.send(b'hello werld')                                               
        print(conn.recv(2048))                                                  
        print('done') 

And it failed with a TimeoutError! So the connection is definitely being made; it just never receives a response from the server in time. Bumping up the timeout didn't help.

At this point I realized the loop isn't running, but it also isn't closed. So how can I tell asyncio that it's OK to process all the waiting tasks?

There's run and run_until_complete, but I need some task to give them, because I don't want to run_forever. And then I realized: asyncio.sleep

always suspends the current task, allowing other tasks to run.

Which... is exactly what I wanted. It does say

Setting the delay to 0 provides an optimized path to allow other tasks to run. This can be used by long-running functions to avoid blocking the event loop for the full duration of the function call.

But doing this seemed to end up closing the loop instead.

It is possible to combine our threaded code with asyncio, using some of the lower-level primitives. However, if we're using threading instead of multiprocessing, it appears that the GIL still interferes and we have to ensure cooperation ourselves. The process looks like this:

  1. Run sync code as normal, and create a threading.Event (or multiprocessing.Event).
  2. Spawn server and client threads, passing the event and any other necessary info.
  3. In the server thread, create a new asyncio loop and server. While the event isn't set, keep having the loop wait on asyncio.sleep, which lets it process background tasks (see the short sketch after this list).
  4. In the client, connect to the server and send messages; it may be necessary to time.sleep, though that might just be because the server reads too quickly. Set the event when we're done sending messages.
  5. In the server thread, either close the loop or just let it close when we exit the thread now that the event is set.
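
The heart of step 3, pulled out of the full example below (quit_event and loop are created there), is just:

while not quit_event.wait(0.01):
    loop.run_until_complete(asyncio.sleep(0))
loop.close()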

Here's an example that works with either threads or multiprocessing:

import asyncio                                                                  
import argparse                                                                 
import time                                                                     
import socket                                                                   
import concurrent.futures                                                       
import threading                                                                
import multiprocessing                                                          
                                                                                
                                                                                
class EchoHandler(asyncio.Protocol):                                            
    def data_received(self, data):                                              
        print(f"Got this data: {data.decode()}")                                
                                                                                
def launch_server(host, port):                                                  
    print("starting loop")                                                      
    loop = asyncio.new_event_loop()                                             
    print("serving server")                                                     
    server = loop.run_until_complete(loop.create_server(EchoHandler, host, port, start_serving=True))
    print("server launched")                                                    
    while not quit_event.wait(0.01):                                            
        loop.run_until_complete(asyncio.sleep(0))                               
    loop.close()                                                                
    return 0                                                                    
                                                                                
def launch_client(host, port):                                                  
    print("client launched")                                                    
    time.sleep(0.1)                                                             
    for _ in range(3):                                                          
        print("connecting")                                                     
        try:                                                                    
            with socket.create_connection((host, port), timeout=0.5) as conn:   
                print('connected & sending')                                    
                for i in range(1,4):                                            
                    conn.send(f'hello werld {i}'.encode())                      
                    time.sleep(0.01)                                            
                time.sleep(0.1)                                                 
                print("client sent data")                                       
        except Exception as ex:                                                 
            print("oops", ex)                                                   
    quit_event.set()                                                            
    return 0                                                                    
                                                                                
def init_event(event):                                                          
    global quit_event                                                           
    quit_event = event                                                          
                                                                                
def test_the_server(*, host, port):                                             
    quit_event = multiprocessing.Event()                                        
    #quit_event = threading.Event()
    #with concurrent.futures.ThreadPoolExecutor(2, initializer=init_event, initargs=(quit_event,)) as ex:
    with concurrent.futures.ProcessPoolExecutor(2, initializer=init_event, initargs=(quit_event,)) as ex:
        client_thread = ex.submit(launch_client, host=host, port=port)          
        server_thread = ex.submit(launch_server, host=host, port=port)          
    return client_thread.result() + server_thread.result()                      
                                                                                
                                                                                
def do_it():  # Shia LeBeouf!                                                   
    parser = argparse.ArgumentParser(usage=__doc__)                             
    parser.add_argument("--host", default="127.0.0.1")                          
    parser.add_argument("-p", "--port", type=int, default=60025)                
    args = parser.parse_args()                                                  
                                                                                
    exit(test_the_server(host=args.host, port=args.port))                       
                                                                                
                                                                                
if __name__ == "__main__":                                                      
    do_it()