How to write a demo that forwards messages to peers using Rust rocket_ws

121 Views Asked by At

I'm learning WebSockets with Rust's Rocket framework, but there is very little reference code I can refer to. The example code provided by rocket_ws only echoes a message back to its sender.

/// Echo endpoint: every frame read from the client is written straight back
/// to that same client.
///
/// The handler finishes with `Ok(())` once the incoming stream is exhausted,
/// and returns early with the error if reading a frame fails.
#[get("/echo")]
fn echo(ws: ws::WebSocket) -> ws::Channel<'static> {
    ws.channel(move |mut stream| {
        Box::pin(async move {
            loop {
                // `None` means the client side of the stream has ended.
                let frame = match stream.next().await {
                    Some(frame) => frame?,
                    None => break,
                };
                // A failed write is deliberately ignored, as in the sample.
                let _ = stream.send(frame).await;
            }

            Ok(())
        })
    })
}

But I want to forward messages to other WebSocket clients. Would you be so kind as to write a small demo that sends each message received by rocket_ws to all the peer clients connected to this WebSocket server?

I tried to use a HashMap, but it failed, and I don't think this solution is correct at all.

use futures::{SinkExt, StreamExt};
use rocket::{get, routes, State, data::IoStream};
use ws::stream::DuplexStream;
use tokio::sync::Mutex;
use std::{net::SocketAddr, sync::{Arc}, collections::HashMap};
type PeersMap = Arc<Mutex<HashMap<SocketAddr, i32>>>;

mod entity;

// static USER_MAP: Arc<Mutex<HashMap<i32, DuplexStream>>> = Arc::new(Mutex::new(HashMap::new()));

/// WebSocket route that tries to relay a text message to the user named in it.
///
/// The connection's whole duplex stream is stored in the shared `user_map`
/// under `user_id`. Each incoming text frame is parsed as an
/// `entity::ReceivedMessage`, its target id is looked up in the same map, and
/// the text is sent both back to the sender and on to the target.
///
/// NOTE(review): this is the non-working attempt. The guard obtained in the
/// `if let` condition below appears to stay alive for the whole `if let` body,
/// so the second `user_map.lock().await` inside the loop waits on a mutex this
/// task already holds and is never expected to complete.
#[get("/<user_id>")]
async fn ws_test<'a>(user_id: i32, socket: ws::WebSocket, user_map: &'a State<Arc<Mutex<HashMap<i32, DuplexStream>>>>) -> ws::Channel<'a> {

    socket.channel(move |stream| Box::pin(async move {
        // The stream is moved into the map, so from here on it can only be
        // reached through the map's lock.
        user_map.lock().await.insert(user_id, stream);

        // The guard created here is a temporary of the `if let` condition and
        // `stream` borrows from it, so the map stays locked while we read.
        if let Some(stream) = user_map.lock().await.get_mut(&user_id) {
            while let Some(message) = stream.next().await {
                // let target = user_map.lock().unwrap().get(k);
                dbg!(&message);
                match message {
                    Ok(res) => {
                        // Only text frames are handled; other frame kinds are skipped.
                        if let ws::Message::Text(concrete) = res {
                            // Both unwraps panic on input that is not valid JSON
                            // or whose target id is not an integer.
                            let received_message: entity::ReceivedMessage = serde_json::from_str(&concrete).unwrap();
                            let target_id = received_message.get_target_id().parse::<i32>().unwrap();
                            // let message =
                            // Second lock on the same map while the outer guard is
                            // still held (see the note on the function).
                            match user_map.lock().await.get_mut(&target_id) {
                                Some(target_stream) => {
                                    // Echo to the sender first, then forward to the target.
                                    let _ = stream.send(ws::Message::Text(concrete.clone())).await;
                                    let _ = target_stream.send(ws::Message::Text(concrete)).await;
                                },
                                // Reached when `target_id` is not in the map; the printed
                                // text does not describe that situation.
                                None => println!("only support text message"),
                            };
                        }
                    },
                    // Read errors are only printed; the loop keeps polling.
                    Err(err) => println!("{err}"),
                }
            }
        }
        Ok(())
    }))

}

/// Starts the Rocket server with an empty user map registered as managed
/// state and the `ws_test` route mounted under `/ws`.
///
/// Returns any error reported by `launch`.
#[rocket::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Shared registry handed to every request through `State`.
    let user_map = Arc::new(Mutex::new(HashMap::<i32, DuplexStream>::new()));

    let server = rocket::build()
        .manage(user_map)
        .mount("/ws", routes![ws_test]);

    server.launch().await?;
    Ok(())
}
1

There is 1 best solution below

0
Looouiiis On

I have found that the struct DuplexStream can be split into a sink and a stream, and I can use them to receive and send messages separately. With that, I successfully finished my code:

use futures::{SinkExt, StreamExt, stream::SplitSink};
use rocket::{get, routes, State};
use ws::{stream::DuplexStream, Message};
use tokio::sync::Mutex;
use std::{sync::Arc, collections::HashMap};

mod entity;

mod auth;

/// WebSocket route that relays text messages between connected users.
///
/// A client connects on `/<user_id>`. The socket is split into a write half
/// and a read half; the write half is stored in `user_map` under `user_id` so
/// that other connections can send to this client.
///
/// Every incoming text frame is parsed as an `entity::ReceivedMessage`. If its
/// target id names a connected user, the original text is forwarded to that
/// user. Otherwise (unknown target, or a frame that cannot be parsed) the
/// sender receives a "target not found" notice instead of the handler
/// panicking.
///
/// When the client sends a Close frame, the read side fails, or the stream
/// simply ends, the user's entry is removed from the map and the write half is
/// closed.
#[get("/<user_id>")]
async fn ws_test<'a>(user_id: i32, socket: ws::WebSocket, user_map: &'a State<Arc<Mutex<HashMap<i32, Arc<Mutex<SplitSink<DuplexStream, Message>>>>>>>) -> ws::Channel<'a> {
    socket.channel(move |stream| Box::pin(async move {
        let (sender, mut receiver) = stream.split();
        let arc_sender = Arc::new(Mutex::new(sender));

        // Wait for the map instead of `try_lock` + panic: contention with
        // another connection is normal and must not crash this handler.
        user_map.lock().await.insert(user_id, Arc::clone(&arc_sender));
        println!("{user_id}接入成功");

        // The loop yields `true` when it stopped because of a read error.
        let abnormal = loop {
            let message = match receiver.next().await {
                Some(message) => message,
                // Stream ended without an explicit Close frame.
                None => break false,
            };

            match message {
                Ok(ws::Message::Text(concrete)) => {
                    // Client input is untrusted: a malformed frame yields `None`
                    // rather than a panic.
                    let target_id = serde_json::from_str::<entity::ReceivedMessage>(&concrete)
                        .ok()
                        .and_then(|received| received.get_target_id().parse::<i32>().ok());

                    // Clone the target's handle so the map lock is released
                    // before any network I/O happens.
                    let target_sender = match target_id {
                        Some(id) => user_map.lock().await.get(&id).cloned(),
                        None => None,
                    };

                    match target_sender {
                        Some(target_sender) => {
                            let _ = target_sender.lock().await.send(ws::Message::Text(concrete)).await;
                        }
                        None => {
                            let _ = arc_sender.lock().await.send(ws::Message::Text("未找到目标".to_string())).await;
                        }
                    }
                }
                Ok(ws::Message::Close(_)) => {
                    println!("客户端{user_id}关闭了连接");
                    break false;
                }
                // Other frame kinds are ignored, as before.
                Ok(_) => {}
                Err(_) => {
                    print!("处理关闭时仍在连接的客户端{user_id}...");
                    break true;
                }
            }
        };

        // Single cleanup path for every way the loop can end.
        {
            let mut map = user_map.lock().await;
            // Remove the entry only if it is still ours: a newer connection
            // with the same id may already have replaced it.
            let still_ours = map
                .get(&user_id)
                .map_or(false, |current| Arc::ptr_eq(current, &arc_sender));
            if still_ours {
                map.remove(&user_id);
            }
        }
        // `close()` returns a future; it has no effect unless awaited.
        let _ = arc_sender.lock().await.close().await;
        if abnormal {
            println!("完毕");
        }

        Ok(())
    }))
}

/// Starts the Rocket server with an empty registry of per-user write halves
/// as managed state and the `ws_test` route mounted under `/ws`.
///
/// Returns any error reported by `launch`.
#[rocket::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Maps a user id to the shared write half of that user's socket.
    type Peers = HashMap<i32, Arc<Mutex<SplitSink<DuplexStream, Message>>>>;

    let user_map = Arc::new(Mutex::new(Peers::new()));

    let server = rocket::build()
        .manage(user_map)
        .mount("/ws", routes![ws_test]);

    server.launch().await?;
    Ok(())
}