use crate::dto::{ ApiCatalogResponse, ApiCommandRequest, ApiCommandResponse, ApiDiscoveredNodeType, ApiDiscoveryResult, ApiDiscoveryScanRequest, ApiDiscoveryScanResponse, ApiErrorResponse, ApiGroupListResponse, ApiPresetListResponse, ApiPreviewResponse, ApiSnapshotResponse, ApiStateResponse, ApiStateSnapshot, ApiStreamEnvelope, ApiStreamMessage, API_VERSION, }; use crate::websocket::{websocket_accept_value, write_text_frame}; use infinity_host::HostApiPort; use serde_json::Value; use std::collections::HashMap; use std::io::{self, Read, Write}; use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream}; use std::sync::{ atomic::{AtomicBool, Ordering}, mpsc, Arc, Mutex, }; use std::thread::{self, JoinHandle}; use std::time::{Duration, Instant}; pub struct HostApiServer { local_addr: SocketAddr, shutdown: Arc, accept_thread: Option>, } #[derive(Debug)] struct ApiRequestError { status: u16, code: String, message: String, } impl HostApiServer { pub fn bind(bind: &str, service: Arc) -> io::Result { let listener = TcpListener::bind(bind)?; listener.set_nonblocking(true)?; let local_addr = listener.local_addr()?; let shutdown = Arc::new(AtomicBool::new(false)); let thread_shutdown = Arc::clone(&shutdown); let accept_thread = thread::spawn(move || accept_loop(listener, service, thread_shutdown)); Ok(Self { local_addr, shutdown, accept_thread: Some(accept_thread), }) } pub fn local_addr(&self) -> SocketAddr { self.local_addr } pub fn shutdown(mut self) { self.shutdown.store(true, Ordering::SeqCst); if let Some(handle) = self.accept_thread.take() { let _ = handle.join(); } } } impl Drop for HostApiServer { fn drop(&mut self) { self.shutdown.store(true, Ordering::SeqCst); if let Some(handle) = self.accept_thread.take() { let _ = handle.join(); } } } fn accept_loop(listener: TcpListener, service: Arc, shutdown: Arc) { while !shutdown.load(Ordering::SeqCst) { match listener.accept() { Ok((stream, _)) => { let service = Arc::clone(&service); thread::spawn(move || { let _ = handle_connection(stream, service); }); } Err(error) if error.kind() == io::ErrorKind::WouldBlock => { thread::sleep(Duration::from_millis(25)); } Err(_) => break, } } } fn handle_connection(mut stream: TcpStream, service: Arc) -> io::Result<()> { stream.set_read_timeout(Some(Duration::from_secs(2)))?; let request = read_request(&mut stream)?; if request.path == "/api/v1/stream" && request.is_websocket() { return handle_websocket(stream, request, service); } match (request.method.as_str(), request.path.as_str()) { ("GET", "/api/v1/snapshot") => { let snapshot = service.snapshot(); respond_json( &mut stream, 200, &ApiSnapshotResponse::from_snapshot(&snapshot), ) } ("GET", "/api/v1/state") => { let snapshot = service.snapshot(); respond_json( &mut stream, 200, &ApiStateResponse::from_snapshot(&snapshot), ) } ("GET", "/api/v1/preview") => { let snapshot = service.snapshot(); respond_json( &mut stream, 200, &ApiPreviewResponse::from_snapshot(&snapshot), ) } ("GET", "/api/v1/catalog") => { let snapshot = service.snapshot(); respond_json( &mut stream, 200, &ApiCatalogResponse::from_snapshot(&snapshot), ) } ("GET", "/api/v1/presets") => { let snapshot = service.snapshot(); respond_json( &mut stream, 200, &ApiPresetListResponse::from_snapshot(&snapshot), ) } ("GET", "/api/v1/groups") => { let snapshot = service.snapshot(); respond_json( &mut stream, 200, &ApiGroupListResponse::from_snapshot(&snapshot), ) } ("POST", "/api/v1/command") => match handle_command_post(&mut stream, request, service) { Ok(()) => Ok(()), Err(error) => respond_error(&mut stream, error.status, error.code, error.message), }, ("POST", "/api/v1/discovery/scan") => { match handle_discovery_scan_post(&mut stream, request) { Ok(()) => Ok(()), Err(error) => respond_error(&mut stream, error.status, error.code, error.message), } } ("GET", "/") => respond_text( &mut stream, 200, "text/html; charset=utf-8", include_str!("../../../web/v1/index.html"), ), ("GET", "/index.html") => respond_text( &mut stream, 200, "text/html; charset=utf-8", include_str!("../../../web/v1/index.html"), ), ("GET", "/technical") => respond_text( &mut stream, 200, "text/html; charset=utf-8", include_str!("../../../web/v1/technical.html"), ), ("GET", "/technical.html") => respond_text( &mut stream, 200, "text/html; charset=utf-8", include_str!("../../../web/v1/technical.html"), ), ("GET", "/app.js") => respond_text( &mut stream, 200, "application/javascript; charset=utf-8", include_str!("../../../web/v1/app.js"), ), ("GET", "/technical.js") => respond_text( &mut stream, 200, "application/javascript; charset=utf-8", include_str!("../../../web/v1/technical.js"), ), ("GET", "/styles.css") => respond_text( &mut stream, 200, "text/css; charset=utf-8", include_str!("../../../web/v1/styles.css"), ), _ => respond_text( &mut stream, 404, "application/json; charset=utf-8", &serde_json::to_string_pretty(&ApiErrorResponse::new( "not_found", format!( "no route registered for {} {}", request.method, request.path ), )) .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?, ), } } fn handle_command_post( stream: &mut TcpStream, request: HttpRequest, service: Arc, ) -> Result<(), ApiRequestError> { let parsed = serde_json::from_slice::(&request.body).map_err(|error| { ApiRequestError { status: 400, code: "invalid_request_json".to_string(), message: format!("command request body could not be parsed: {error}"), } })?; let request_id = parsed.request_id.clone(); let command_type = parsed.command.kind_label().to_string(); let command = parsed .into_host_command() .map_err(|error| ApiRequestError { status: 400, code: "invalid_command".to_string(), message: error, })?; let outcome = service .send_command(command) .map_err(|error| ApiRequestError { status: 400, code: error.code, message: error.message, })?; respond_json( stream, 200, &ApiCommandResponse { api_version: API_VERSION, accepted: true, request_id, generated_at_millis: outcome.generated_at_millis, command_type, summary: outcome.summary, }, ) .map_err(|error| ApiRequestError { status: 500, code: "response_write_failed".to_string(), message: error.to_string(), }) } fn handle_websocket( mut stream: TcpStream, request: HttpRequest, service: Arc, ) -> io::Result<()> { let Some(key) = request.header("sec-websocket-key") else { return respond_error( &mut stream, 400, "missing_websocket_key", "websocket upgrade requires sec-websocket-key", ); }; let accept = websocket_accept_value(key); let response = format!( "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {accept}\r\n\r\n" ); stream.write_all(response.as_bytes())?; let mut sequence = 1u64; let mut last_event_millis = None::; let mut last_event_signatures = Vec::<(Option, String)>::new(); let mut last_streamed_preview = None::; loop { let snapshot = service.snapshot(); send_stream_message( &mut stream, sequence, snapshot.generated_at_millis, ApiStreamMessage::Snapshot(ApiStateSnapshot::from_snapshot(&snapshot)), )?; sequence += 1; let preview_payload = crate::dto::ApiPreviewSnapshot::from_snapshot(&snapshot); if last_streamed_preview .as_ref() .map(|previous| previous != &preview_payload) .unwrap_or(true) { send_stream_message( &mut stream, sequence, snapshot.generated_at_millis, ApiStreamMessage::Preview(preview_payload.clone()), )?; sequence += 1; last_streamed_preview = Some(preview_payload); } let mut new_events = snapshot .recent_events .iter() .filter(|event| match last_event_millis { None => true, Some(last_millis) if event.at_millis > last_millis => true, Some(last_millis) if event.at_millis == last_millis => !last_event_signatures .iter() .any(|signature| signature.0 == event.code && signature.1 == event.message), Some(_) => false, }) .cloned() .collect::>(); new_events.sort_by_key(|event| event.at_millis); for event in new_events { let event_millis = event.at_millis; let current_signature = (event.code.clone(), event.message.clone()); send_stream_message( &mut stream, sequence, event_millis, ApiStreamMessage::Event(event.into()), )?; sequence += 1; match last_event_millis { Some(last_millis) if last_millis == event_millis => { last_event_signatures.push(current_signature); } _ => { last_event_millis = Some(event_millis); last_event_signatures = vec![current_signature]; } } } thread::sleep(Duration::from_millis(250)); } } fn handle_discovery_scan_post( stream: &mut TcpStream, request: HttpRequest, ) -> Result<(), ApiRequestError> { let parsed = serde_json::from_slice::(&request.body).map_err(|error| { ApiRequestError { status: 400, code: "invalid_request_json".to_string(), message: format!("discovery request body could not be parsed: {error}"), } })?; let targets = parse_subnet_targets(&parsed.subnet).map_err(|message| ApiRequestError { status: 400, code: "invalid_subnet_cidr".to_string(), message, })?; let started_at = Instant::now(); let mut results = scan_subnet_targets(&targets); results.sort_by_key(|result| { result .ip .parse::() .map(u32::from) .unwrap_or_default() }); let reachable_hosts = results.iter().filter(|result| result.reachable).count(); respond_json( stream, 200, &ApiDiscoveryScanResponse { api_version: API_VERSION, subnet: parsed.subnet.trim().to_string(), scanned_hosts: targets.len(), reachable_hosts, results, }, ) .map_err(|error| ApiRequestError { status: 500, code: "response_write_failed".to_string(), message: format!( "discovery response could not be written after {} ms: {error}", started_at.elapsed().as_millis() ), }) } fn parse_subnet_targets(raw_subnet: &str) -> Result, String> { const MAX_SCAN_HOSTS: u64 = 1024; let subnet = raw_subnet.trim(); let (address, prefix) = subnet .split_once('/') .ok_or_else(|| format!("subnet '{subnet}' must be in CIDR form, e.g. 192.168.40.0/24"))?; let ip = address .trim() .parse::() .map_err(|_| format!("subnet '{subnet}' contains an invalid IPv4 address"))?; let prefix = prefix .trim() .parse::() .map_err(|_| format!("subnet '{subnet}' contains an invalid CIDR prefix"))?; if prefix > 32 { return Err(format!( "subnet '{subnet}' has prefix {prefix}, expected 0..=32" )); } let host_span = 1u64 << (32u8.saturating_sub(prefix)); if host_span > MAX_SCAN_HOSTS { return Err(format!( "subnet '{subnet}' spans {host_span} addresses, limit is {MAX_SCAN_HOSTS}" )); } let ip_u32 = u32::from(ip); let mask = if prefix == 0 { 0 } else { u32::MAX << (32 - u32::from(prefix)) }; let network = ip_u32 & mask; let broadcast = network | !mask; let (start, end) = if prefix >= 31 { (network, broadcast) } else { (network.saturating_add(1), broadcast.saturating_sub(1)) }; if start > end { return Err(format!("subnet '{subnet}' has no scanable host addresses")); } Ok((start..=end).map(Ipv4Addr::from).collect()) } fn scan_subnet_targets(targets: &[Ipv4Addr]) -> Vec { if targets.is_empty() { return Vec::new(); } let worker_count = usize::min(32, targets.len().max(1)); let (job_sender, job_receiver) = mpsc::channel::(); let job_receiver = Arc::new(Mutex::new(job_receiver)); let (result_sender, result_receiver) = mpsc::channel::(); let mut handles = Vec::with_capacity(worker_count); for _ in 0..worker_count { let receiver = Arc::clone(&job_receiver); let sender = result_sender.clone(); handles.push(thread::spawn(move || loop { let next_job = { let guard = receiver.lock(); match guard { Ok(receiver) => receiver.recv().ok(), Err(_) => None, } }; let Some(ip) = next_job else { break; }; let _ = sender.send(probe_ip(ip)); })); } drop(result_sender); for ip in targets { let _ = job_sender.send(*ip); } drop(job_sender); let mut results = Vec::with_capacity(targets.len()); for _ in 0..targets.len() { if let Ok(result) = result_receiver.recv() { results.push(result); } } for handle in handles { let _ = handle.join(); } results } fn probe_ip(ip: Ipv4Addr) -> ApiDiscoveryResult { let mut reachable = false; let mut detected_type = ApiDiscoveredNodeType::Unknown; let mut hostname = None; if let Some(info_probe) = probe_http_endpoint(ip, 80, "/json/info") { reachable = true; detected_type = detect_node_type(&info_probe.body, detected_type); hostname = extract_probe_hostname(&info_probe); } else if can_connect(ip, 80) { reachable = true; } if !reachable && can_connect(ip, 81) { reachable = true; } if detected_type == ApiDiscoveredNodeType::Unknown { if let Some(node_probe) = probe_http_endpoint(ip, 80, "/api/v1/node/info") { reachable = true; detected_type = detect_node_type(&node_probe.body, detected_type); if hostname.is_none() { hostname = extract_probe_hostname(&node_probe); } } else if let Some(state_probe) = probe_http_endpoint(ip, 80, "/api/v1/state") { reachable = true; detected_type = detect_node_type(&state_probe.body, detected_type); if hostname.is_none() { hostname = extract_probe_hostname(&state_probe); } } } ApiDiscoveryResult { ip: ip.to_string(), reachable, detected_type, hostname, } } fn can_connect(ip: Ipv4Addr, port: u16) -> bool { let address = SocketAddr::new(IpAddr::V4(ip), port); TcpStream::connect_timeout(&address, Duration::from_millis(120)).is_ok() } #[derive(Debug)] struct HttpProbe { headers: HashMap, body: String, } fn probe_http_endpoint(ip: Ipv4Addr, port: u16, path: &str) -> Option { let address = SocketAddr::new(IpAddr::V4(ip), port); let mut stream = TcpStream::connect_timeout(&address, Duration::from_millis(120)).ok()?; let _ = stream.set_read_timeout(Some(Duration::from_millis(180))); let _ = stream.set_write_timeout(Some(Duration::from_millis(120))); let request = format!( "GET {path} HTTP/1.1\r\nHost: {ip}\r\nConnection: close\r\nAccept: application/json\r\n\r\n" ); stream.write_all(request.as_bytes()).ok()?; let mut raw = Vec::new(); stream.read_to_end(&mut raw).ok()?; let header_end = find_header_end(&raw)?; let header_text = String::from_utf8_lossy(&raw[..header_end]); let headers = header_text .lines() .skip(1) .filter_map(|line| line.split_once(':')) .map(|(key, value)| (key.trim().to_ascii_lowercase(), value.trim().to_string())) .collect::>(); let body = String::from_utf8_lossy(raw.get(header_end + 4..).unwrap_or_default()).to_string(); Some(HttpProbe { headers, body }) } fn detect_node_type(body: &str, fallback: ApiDiscoveredNodeType) -> ApiDiscoveredNodeType { let lowered = body.to_ascii_lowercase(); if lowered.contains("\"wled\"") || lowered.contains("\"brand\":\"wled\"") || lowered.contains("\"product\":\"wled\"") { return ApiDiscoveredNodeType::Wled; } if lowered.contains("\"native_node\"") || lowered.contains("\"node_kind\":\"native\"") || lowered.contains("\"infinity_node\"") { return ApiDiscoveredNodeType::NativeNode; } fallback } fn extract_probe_hostname(probe: &HttpProbe) -> Option { if let Ok(json) = serde_json::from_str::(&probe.body) { let name = json .get("name") .and_then(Value::as_str) .or_else(|| { json.get("info") .and_then(|value| value.get("name")) .and_then(Value::as_str) }) .or_else(|| json.get("mdns").and_then(Value::as_str)); if let Some(name) = name { let trimmed = name.trim(); if !trimmed.is_empty() { return Some(trimmed.to_string()); } } } probe.headers.get("server").and_then(|value| { let trimmed = value.trim(); if trimmed.is_empty() { None } else { Some(trimmed.to_string()) } }) } fn send_stream_message( stream: &mut TcpStream, sequence: u64, generated_at_millis: u64, message: ApiStreamMessage, ) -> io::Result<()> { let payload = serde_json::to_string(&ApiStreamEnvelope { api_version: API_VERSION, sequence, generated_at_millis, message, }) .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; write_text_frame(stream, &payload) } fn respond_json( stream: &mut TcpStream, status: u16, body: &T, ) -> io::Result<()> { let payload = serde_json::to_string_pretty(body) .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; respond_text(stream, status, "application/json; charset=utf-8", &payload) } fn respond_error( stream: &mut TcpStream, status: u16, code: impl Into, message: impl Into, ) -> io::Result<()> { respond_json(stream, status, &ApiErrorResponse::new(code, message)) } fn respond_text( stream: &mut TcpStream, status: u16, content_type: &str, body: &str, ) -> io::Result<()> { let reason = match status { 200 => "OK", 400 => "Bad Request", 404 => "Not Found", 500 => "Internal Server Error", _ => "OK", }; let response = format!( "HTTP/1.1 {status} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", body.as_bytes().len(), body ); stream.write_all(response.as_bytes()) } #[derive(Debug)] struct HttpRequest { method: String, path: String, headers: HashMap, body: Vec, } impl HttpRequest { fn header(&self, key: &str) -> Option<&str> { self.headers .get(&key.to_ascii_lowercase()) .map(|value| value.as_str()) } fn is_websocket(&self) -> bool { self.header("upgrade") .map(|value| value.eq_ignore_ascii_case("websocket")) .unwrap_or(false) } } fn read_request(stream: &mut TcpStream) -> io::Result { let mut buffer = Vec::new(); let mut temp = [0u8; 4096]; let mut header_end = None; let mut expected_len = None; loop { let read = stream.read(&mut temp)?; if read == 0 { break; } buffer.extend_from_slice(&temp[..read]); if header_end.is_none() { header_end = find_header_end(&buffer); if let Some(end) = header_end { let header_text = String::from_utf8_lossy(&buffer[..end]); expected_len = parse_content_length(&header_text); if expected_len == Some(0) || expected_len.is_none() { break; } } } if let (Some(end), Some(content_len)) = (header_end, expected_len) { if buffer.len() >= end + 4 + content_len { break; } } } let header_end = header_end .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing header end"))?; let header_text = String::from_utf8_lossy(&buffer[..header_end]); let mut lines = header_text.lines(); let request_line = lines .next() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing request line"))?; let mut request_parts = request_line.split_whitespace(); let method = request_parts .next() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing method"))? .to_string(); let path = request_parts .next() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing path"))? .split('?') .next() .unwrap_or("/") .to_string(); let mut headers = HashMap::new(); for line in lines { if let Some((key, value)) = line.split_once(':') { headers.insert(key.trim().to_ascii_lowercase(), value.trim().to_string()); } } let body_start = header_end + 4; let body = buffer.get(body_start..).unwrap_or_default().to_vec(); Ok(HttpRequest { method, path, headers, body, }) } fn parse_content_length(header_text: &str) -> Option { header_text.lines().find_map(|line| { line.split_once(':').and_then(|(key, value)| { if key.trim().eq_ignore_ascii_case("content-length") { value.trim().parse::().ok() } else { None } }) }) } fn find_header_end(buffer: &[u8]) -> Option { buffer.windows(4).position(|window| window == b"\r\n\r\n") }