-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathws_server.rs
More file actions
187 lines (162 loc) · 5.66 KB
/
ws_server.rs
File metadata and controls
187 lines (162 loc) · 5.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
/*
* WebSocket Server
*
* PURPOSE:
* This high-performance WebSocket server acts as the communication bridge between
* the RL simulation backend and the frontend visualization. It efficiently handles
* real-time data streaming with support for thousands of concurrent connections.
*
* KEY FUNCTIONS:
* - Maintains connection state for all connected clients
* - Broadcasts simulation updates to all clients efficiently
* - Provides immediate state synchronization for new connections
* - Handles parameter updates from clients back to the simulation
* - Uses Axum and Tokio for asynchronous, non-blocking I/O
*/
// ws_server.rs
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::IntoResponse,
routing::get,
Router,
};
use futures::{SinkExt, StreamExt};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{mpsc, RwLock};
use tower_http::cors::CorsLayer;
/// Shared server state, handed to every connection handler behind an `Arc`.
struct AppState {
    /// Outbound message channel for each connected client, keyed by the
    /// per-connection user ID.
    users: RwLock<HashMap<String, mpsc::Sender<Message>>>,
    /// Most recent simulation snapshot; sent to clients right after they connect.
    latest_state: RwLock<SimulationState>,
}
/// Simplified snapshot of the simulation that is serialized to JSON and
/// pushed to clients.
///
/// Field order is kept stable because it determines the order of keys in the
/// serialized output.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
struct SimulationState {
    /// Snapshot timestamp. NOTE(review): unit/epoch is not established in this
    /// file — confirm with the producer.
    timestamp: u64,
    /// Per-node traffic readings.
    traffic_data: Vec<TrafficNode>,
    /// Per-node energy readings.
    energy_data: Vec<EnergyNode>,
    /// Per-node waste readings.
    waste_data: Vec<WasteNode>,
    // Additional fields as needed
}
// Per-domain node records carried inside `SimulationState`.

/// One traffic node in a simulation snapshot.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
struct TrafficNode {
    /// Node identifier.
    id: String,
    /// Congestion reading. NOTE(review): scale/units are not shown here.
    congestion: f32,
    // Other traffic-specific fields
}
/// One energy node in a simulation snapshot.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
struct EnergyNode {
    /// Node identifier.
    id: String,
    /// Load reading. NOTE(review): scale/units are not shown here.
    load: f32,
    // Other energy-specific fields
}
/// One waste node in a simulation snapshot.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
struct WasteNode {
    /// Node identifier.
    id: String,
    /// Fill-level reading. NOTE(review): scale/units are not shown here.
    fill_level: f32,
    // Other waste-specific fields
}
async fn websocket_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_socket(socket, state))
}
/// Drives one client connection from upgrade until disconnect.
///
/// Steps:
/// 1. Registers a bounded outbound channel for the client in `state.users`
///    under a freshly generated UUID.
/// 2. Sends the latest simulation snapshot directly over the socket.
/// 3. Runs two tasks: one reading client messages, one forwarding queued
///    broadcast messages to the socket.
/// 4. When either task ends, aborts the other and removes the client from
///    `state.users`.
///
/// The removal happens here, after the tasks, rather than inside the receive
/// task. That way it also runs when the send side fails first and the receive
/// task is aborted; otherwise the map entry and its `Sender` would leak.
async fn handle_socket(socket: WebSocket, state: Arc<AppState>) {
    let (mut sender, mut receiver) = socket.split();

    // Generate a unique user ID
    let user_id = uuid::Uuid::new_v4().to_string();

    // Channel through which broadcasts reach this client; the sending half
    // lives in the shared user map.
    let (tx, mut rx) = mpsc::channel::<Message>(100);
    state.users.write().await.insert(user_id.clone(), tx);

    // Send the latest state immediately to new connections. The read guard is
    // released at the end of the statement, before any network I/O.
    let latest_state = state.latest_state.read().await.clone();
    match serde_json::to_string(&latest_state) {
        Ok(state_json) => {
            if sender.send(Message::Text(state_json)).await.is_err() {
                // The socket is already unusable: unregister and stop here.
                state.users.write().await.remove(&user_id);
                return;
            }
        }
        // Do not take the connection down over a bad snapshot; the client will
        // still receive subsequent broadcasts.
        Err(err) => eprintln!("failed to serialize initial state for {user_id}: {err}"),
    }

    // Handle incoming messages
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(msg)) = receiver.next().await {
            match msg {
                Message::Text(text) => {
                    // Handle parameter updates from client.
                    // NOTE(review): `SimulationParameters` is not defined in the
                    // visible part of this file — confirm where it is declared.
                    if let Ok(_params) = serde_json::from_str::<SimulationParameters>(&text) {
                        // Forward parameters to simulation controller
                        // ...
                    }
                }
                Message::Close(_) => break,
                _ => {}
            }
        }
    });

    // Forward messages from the channel to the WebSocket
    let mut send_task = tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            if sender.send(msg).await.is_err() {
                break;
            }
        }
    });

    // Wait for either task to finish, then stop the other one.
    tokio::select! {
        _ = (&mut recv_task) => send_task.abort(),
        _ = (&mut send_task) => recv_task.abort(),
    }

    // Always unregister, regardless of which side ended the connection.
    state.users.write().await.remove(&user_id);
}
/// Publishes a new simulation snapshot (called by the RL system's update path).
///
/// Serializes `new_state` to JSON once, stores it as the latest snapshot for
/// future connections, and queues the JSON message on every connected
/// client's channel.
///
/// Serialization happens before the value is stored so the snapshot can be
/// moved into place without cloning it. If serialization fails, the error is
/// reported and the update is dropped: neither the stored snapshot nor the
/// clients are touched, and the calling task is not panicked.
async fn update_simulation_state(
    state: Arc<AppState>,
    new_state: SimulationState,
) {
    // Serialize the state once, up front.
    let state_json = match serde_json::to_string(&new_state) {
        Ok(json) => json,
        Err(err) => {
            eprintln!("failed to serialize simulation state: {err}");
            return;
        }
    };

    // Update the latest state (moved, not cloned).
    *state.latest_state.write().await = new_state;

    let msg = Message::Text(state_json);

    // Send to all connected clients
    let users = state.users.read().await;
    for tx in users.values() {
        // Non-blocking send: a full or closed channel just means this client
        // misses the update. Errors are ignored (they'll be cleaned up when the
        // connection drops).
        let _ = tx.try_send(msg.clone());
    }
}
/// Entry point: builds the shared state, starts the simulation-update task,
/// and serves the `/ws` WebSocket route on `0.0.0.0:3001`.
///
/// Failing to bind the listener or a fatal server error is reported on stderr
/// and the process exits with status 1 instead of panicking.
#[tokio::main]
async fn main() {
    // Initialize state with an empty snapshot.
    let state = Arc::new(AppState {
        users: RwLock::new(HashMap::new()),
        latest_state: RwLock::new(SimulationState {
            timestamp: 0,
            traffic_data: vec![],
            energy_data: vec![],
            waste_data: vec![],
        }),
    });

    // Create a channel for simulation updates. The sender is not used yet in
    // this file; it is bound (not discarded with `_`) so the channel stays open
    // and the update task below keeps waiting for producers.
    let (_update_tx, mut update_rx) = mpsc::channel::<SimulationState>(100);

    // Clone state for the update handler
    let update_state = state.clone();

    // Spawn a task to handle simulation updates
    tokio::spawn(async move {
        while let Some(new_state) = update_rx.recv().await {
            update_simulation_state(update_state.clone(), new_state).await;
        }
    });

    // Build the router
    let app = Router::new()
        .route("/ws", get(websocket_handler))
        .layer(CorsLayer::permissive())
        .with_state(state);

    // Start the server
    let listener = match tokio::net::TcpListener::bind("0.0.0.0:3001").await {
        Ok(listener) => listener,
        Err(err) => {
            eprintln!("failed to bind WebSocket server to 0.0.0.0:3001: {err}");
            std::process::exit(1);
        }
    };

    if let Err(err) = axum::serve(listener, app).await {
        eprintln!("WebSocket server terminated with error: {err}");
        std::process::exit(1);
    }
}