773 lines
24 KiB
Rust
773 lines
24 KiB
Rust
use crate::dto::{
|
|
ApiCatalogResponse, ApiCommandRequest, ApiCommandResponse, ApiDiscoveredNodeType,
|
|
ApiDiscoveryResult, ApiDiscoveryScanRequest, ApiDiscoveryScanResponse, ApiErrorResponse,
|
|
ApiGroupListResponse, ApiPresetListResponse, ApiPreviewResponse, ApiSnapshotResponse,
|
|
ApiStateResponse, ApiStateSnapshot, ApiStreamEnvelope, ApiStreamMessage, API_VERSION,
|
|
};
|
|
use crate::websocket::{websocket_accept_value, write_text_frame};
|
|
use infinity_host::HostApiPort;
|
|
use serde_json::Value;
|
|
use std::collections::HashMap;
|
|
use std::io::{self, Read, Write};
|
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream};
|
|
use std::sync::{
|
|
atomic::{AtomicBool, Ordering},
|
|
mpsc, Arc, Mutex,
|
|
};
|
|
use std::thread::{self, JoinHandle};
|
|
use std::time::{Duration, Instant};
|
|
|
|
/// Handle to the running host HTTP/WebSocket API.
///
/// Owns the background accept thread. Calling `shutdown` or dropping the handle
/// stops that thread.
pub struct HostApiServer {
    /// Address the listener reported after binding.
    local_addr: SocketAddr,
    /// Flag polled by the accept loop; storing `true` asks it to exit.
    shutdown: Arc<AtomicBool>,
    /// The accept-loop thread, or `None` once it has been joined.
    accept_thread: Option<JoinHandle<()>>,
}
|
|
|
|
/// Failure of a JSON API handler; `handle_connection` turns it into an error
/// response via `respond_error`.
#[derive(Debug)]
struct ApiRequestError {
    /// HTTP status code for the error response (handlers here use 400 and 500).
    status: u16,
    /// Machine-readable error code such as `invalid_request_json`.
    code: String,
    /// Human-readable explanation sent to the client.
    message: String,
}
|
|
|
|
impl HostApiServer {
|
|
pub fn bind(bind: &str, service: Arc<dyn HostApiPort>) -> io::Result<Self> {
|
|
let listener = TcpListener::bind(bind)?;
|
|
listener.set_nonblocking(true)?;
|
|
let local_addr = listener.local_addr()?;
|
|
let shutdown = Arc::new(AtomicBool::new(false));
|
|
let thread_shutdown = Arc::clone(&shutdown);
|
|
let accept_thread = thread::spawn(move || accept_loop(listener, service, thread_shutdown));
|
|
Ok(Self {
|
|
local_addr,
|
|
shutdown,
|
|
accept_thread: Some(accept_thread),
|
|
})
|
|
}
|
|
|
|
pub fn local_addr(&self) -> SocketAddr {
|
|
self.local_addr
|
|
}
|
|
|
|
pub fn shutdown(mut self) {
|
|
self.shutdown.store(true, Ordering::SeqCst);
|
|
if let Some(handle) = self.accept_thread.take() {
|
|
let _ = handle.join();
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Drop for HostApiServer {
|
|
fn drop(&mut self) {
|
|
self.shutdown.store(true, Ordering::SeqCst);
|
|
if let Some(handle) = self.accept_thread.take() {
|
|
let _ = handle.join();
|
|
}
|
|
}
|
|
}
|
|
|
|
fn accept_loop(listener: TcpListener, service: Arc<dyn HostApiPort>, shutdown: Arc<AtomicBool>) {
|
|
while !shutdown.load(Ordering::SeqCst) {
|
|
match listener.accept() {
|
|
Ok((stream, _)) => {
|
|
let service = Arc::clone(&service);
|
|
thread::spawn(move || {
|
|
let _ = handle_connection(stream, service);
|
|
});
|
|
}
|
|
Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
|
|
thread::sleep(Duration::from_millis(25));
|
|
}
|
|
Err(_) => break,
|
|
}
|
|
}
|
|
}
|
|
|
|
fn handle_connection(mut stream: TcpStream, service: Arc<dyn HostApiPort>) -> io::Result<()> {
|
|
stream.set_read_timeout(Some(Duration::from_secs(2)))?;
|
|
let request = read_request(&mut stream)?;
|
|
if request.path == "/api/v1/stream" && request.is_websocket() {
|
|
return handle_websocket(stream, request, service);
|
|
}
|
|
|
|
match (request.method.as_str(), request.path.as_str()) {
|
|
("GET", "/api/v1/snapshot") => {
|
|
let snapshot = service.snapshot();
|
|
respond_json(
|
|
&mut stream,
|
|
200,
|
|
&ApiSnapshotResponse::from_snapshot(&snapshot),
|
|
)
|
|
}
|
|
("GET", "/api/v1/state") => {
|
|
let snapshot = service.snapshot();
|
|
respond_json(
|
|
&mut stream,
|
|
200,
|
|
&ApiStateResponse::from_snapshot(&snapshot),
|
|
)
|
|
}
|
|
("GET", "/api/v1/preview") => {
|
|
let snapshot = service.snapshot();
|
|
respond_json(
|
|
&mut stream,
|
|
200,
|
|
&ApiPreviewResponse::from_snapshot(&snapshot),
|
|
)
|
|
}
|
|
("GET", "/api/v1/catalog") => {
|
|
let snapshot = service.snapshot();
|
|
respond_json(
|
|
&mut stream,
|
|
200,
|
|
&ApiCatalogResponse::from_snapshot(&snapshot),
|
|
)
|
|
}
|
|
("GET", "/api/v1/presets") => {
|
|
let snapshot = service.snapshot();
|
|
respond_json(
|
|
&mut stream,
|
|
200,
|
|
&ApiPresetListResponse::from_snapshot(&snapshot),
|
|
)
|
|
}
|
|
("GET", "/api/v1/groups") => {
|
|
let snapshot = service.snapshot();
|
|
respond_json(
|
|
&mut stream,
|
|
200,
|
|
&ApiGroupListResponse::from_snapshot(&snapshot),
|
|
)
|
|
}
|
|
("POST", "/api/v1/command") => match handle_command_post(&mut stream, request, service) {
|
|
Ok(()) => Ok(()),
|
|
Err(error) => respond_error(&mut stream, error.status, error.code, error.message),
|
|
},
|
|
("POST", "/api/v1/discovery/scan") => {
|
|
match handle_discovery_scan_post(&mut stream, request) {
|
|
Ok(()) => Ok(()),
|
|
Err(error) => respond_error(&mut stream, error.status, error.code, error.message),
|
|
}
|
|
}
|
|
("GET", "/") => respond_text(
|
|
&mut stream,
|
|
200,
|
|
"text/html; charset=utf-8",
|
|
include_str!("../../../web/v1/index.html"),
|
|
),
|
|
("GET", "/index.html") => respond_text(
|
|
&mut stream,
|
|
200,
|
|
"text/html; charset=utf-8",
|
|
include_str!("../../../web/v1/index.html"),
|
|
),
|
|
("GET", "/technical") => respond_text(
|
|
&mut stream,
|
|
200,
|
|
"text/html; charset=utf-8",
|
|
include_str!("../../../web/v1/technical.html"),
|
|
),
|
|
("GET", "/technical.html") => respond_text(
|
|
&mut stream,
|
|
200,
|
|
"text/html; charset=utf-8",
|
|
include_str!("../../../web/v1/technical.html"),
|
|
),
|
|
("GET", "/app.js") => respond_text(
|
|
&mut stream,
|
|
200,
|
|
"application/javascript; charset=utf-8",
|
|
include_str!("../../../web/v1/app.js"),
|
|
),
|
|
("GET", "/technical.js") => respond_text(
|
|
&mut stream,
|
|
200,
|
|
"application/javascript; charset=utf-8",
|
|
include_str!("../../../web/v1/technical.js"),
|
|
),
|
|
("GET", "/styles.css") => respond_text(
|
|
&mut stream,
|
|
200,
|
|
"text/css; charset=utf-8",
|
|
include_str!("../../../web/v1/styles.css"),
|
|
),
|
|
_ => respond_text(
|
|
&mut stream,
|
|
404,
|
|
"application/json; charset=utf-8",
|
|
&serde_json::to_string_pretty(&ApiErrorResponse::new(
|
|
"not_found",
|
|
format!(
|
|
"no route registered for {} {}",
|
|
request.method, request.path
|
|
),
|
|
))
|
|
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
|
|
),
|
|
}
|
|
}
|
|
|
|
fn handle_command_post(
|
|
stream: &mut TcpStream,
|
|
request: HttpRequest,
|
|
service: Arc<dyn HostApiPort>,
|
|
) -> Result<(), ApiRequestError> {
|
|
let parsed = serde_json::from_slice::<ApiCommandRequest>(&request.body).map_err(|error| {
|
|
ApiRequestError {
|
|
status: 400,
|
|
code: "invalid_request_json".to_string(),
|
|
message: format!("command request body could not be parsed: {error}"),
|
|
}
|
|
})?;
|
|
let request_id = parsed.request_id.clone();
|
|
let command_type = parsed.command.kind_label().to_string();
|
|
let command = parsed
|
|
.into_host_command()
|
|
.map_err(|error| ApiRequestError {
|
|
status: 400,
|
|
code: "invalid_command".to_string(),
|
|
message: error,
|
|
})?;
|
|
let outcome = service
|
|
.send_command(command)
|
|
.map_err(|error| ApiRequestError {
|
|
status: 400,
|
|
code: error.code,
|
|
message: error.message,
|
|
})?;
|
|
respond_json(
|
|
stream,
|
|
200,
|
|
&ApiCommandResponse {
|
|
api_version: API_VERSION,
|
|
accepted: true,
|
|
request_id,
|
|
generated_at_millis: outcome.generated_at_millis,
|
|
command_type,
|
|
summary: outcome.summary,
|
|
},
|
|
)
|
|
.map_err(|error| ApiRequestError {
|
|
status: 500,
|
|
code: "response_write_failed".to_string(),
|
|
message: error.to_string(),
|
|
})
|
|
}
|
|
|
|
fn handle_websocket(
|
|
mut stream: TcpStream,
|
|
request: HttpRequest,
|
|
service: Arc<dyn HostApiPort>,
|
|
) -> io::Result<()> {
|
|
let Some(key) = request.header("sec-websocket-key") else {
|
|
return respond_error(
|
|
&mut stream,
|
|
400,
|
|
"missing_websocket_key",
|
|
"websocket upgrade requires sec-websocket-key",
|
|
);
|
|
};
|
|
|
|
let accept = websocket_accept_value(key);
|
|
let response = format!(
|
|
"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {accept}\r\n\r\n"
|
|
);
|
|
stream.write_all(response.as_bytes())?;
|
|
|
|
let mut sequence = 1u64;
|
|
let mut last_event_millis = None::<u64>;
|
|
let mut last_event_signatures = Vec::<(Option<String>, String)>::new();
|
|
let mut last_streamed_preview = None::<crate::dto::ApiPreviewSnapshot>;
|
|
loop {
|
|
let snapshot = service.snapshot();
|
|
send_stream_message(
|
|
&mut stream,
|
|
sequence,
|
|
snapshot.generated_at_millis,
|
|
ApiStreamMessage::Snapshot(ApiStateSnapshot::from_snapshot(&snapshot)),
|
|
)?;
|
|
sequence += 1;
|
|
|
|
let preview_payload = crate::dto::ApiPreviewSnapshot::from_snapshot(&snapshot);
|
|
if last_streamed_preview
|
|
.as_ref()
|
|
.map(|previous| previous != &preview_payload)
|
|
.unwrap_or(true)
|
|
{
|
|
send_stream_message(
|
|
&mut stream,
|
|
sequence,
|
|
snapshot.generated_at_millis,
|
|
ApiStreamMessage::Preview(preview_payload.clone()),
|
|
)?;
|
|
sequence += 1;
|
|
last_streamed_preview = Some(preview_payload);
|
|
}
|
|
|
|
let mut new_events = snapshot
|
|
.recent_events
|
|
.iter()
|
|
.filter(|event| match last_event_millis {
|
|
None => true,
|
|
Some(last_millis) if event.at_millis > last_millis => true,
|
|
Some(last_millis) if event.at_millis == last_millis => !last_event_signatures
|
|
.iter()
|
|
.any(|signature| signature.0 == event.code && signature.1 == event.message),
|
|
Some(_) => false,
|
|
})
|
|
.cloned()
|
|
.collect::<Vec<_>>();
|
|
new_events.sort_by_key(|event| event.at_millis);
|
|
for event in new_events {
|
|
let event_millis = event.at_millis;
|
|
let current_signature = (event.code.clone(), event.message.clone());
|
|
send_stream_message(
|
|
&mut stream,
|
|
sequence,
|
|
event_millis,
|
|
ApiStreamMessage::Event(event.into()),
|
|
)?;
|
|
sequence += 1;
|
|
match last_event_millis {
|
|
Some(last_millis) if last_millis == event_millis => {
|
|
last_event_signatures.push(current_signature);
|
|
}
|
|
_ => {
|
|
last_event_millis = Some(event_millis);
|
|
last_event_signatures = vec![current_signature];
|
|
}
|
|
}
|
|
}
|
|
|
|
thread::sleep(Duration::from_millis(250));
|
|
}
|
|
}
|
|
|
|
fn handle_discovery_scan_post(
|
|
stream: &mut TcpStream,
|
|
request: HttpRequest,
|
|
) -> Result<(), ApiRequestError> {
|
|
let parsed =
|
|
serde_json::from_slice::<ApiDiscoveryScanRequest>(&request.body).map_err(|error| {
|
|
ApiRequestError {
|
|
status: 400,
|
|
code: "invalid_request_json".to_string(),
|
|
message: format!("discovery request body could not be parsed: {error}"),
|
|
}
|
|
})?;
|
|
let targets = parse_subnet_targets(&parsed.subnet).map_err(|message| ApiRequestError {
|
|
status: 400,
|
|
code: "invalid_subnet_cidr".to_string(),
|
|
message,
|
|
})?;
|
|
|
|
let started_at = Instant::now();
|
|
let mut results = scan_subnet_targets(&targets);
|
|
results.sort_by_key(|result| {
|
|
result
|
|
.ip
|
|
.parse::<Ipv4Addr>()
|
|
.map(u32::from)
|
|
.unwrap_or_default()
|
|
});
|
|
let reachable_hosts = results.iter().filter(|result| result.reachable).count();
|
|
|
|
respond_json(
|
|
stream,
|
|
200,
|
|
&ApiDiscoveryScanResponse {
|
|
api_version: API_VERSION,
|
|
subnet: parsed.subnet.trim().to_string(),
|
|
scanned_hosts: targets.len(),
|
|
reachable_hosts,
|
|
results,
|
|
},
|
|
)
|
|
.map_err(|error| ApiRequestError {
|
|
status: 500,
|
|
code: "response_write_failed".to_string(),
|
|
message: format!(
|
|
"discovery response could not be written after {} ms: {error}",
|
|
started_at.elapsed().as_millis()
|
|
),
|
|
})
|
|
}
|
|
|
|
/// Expands an IPv4 CIDR string such as `192.168.40.0/24` into the addresses to probe.
///
/// Surrounding whitespace is ignored, and host bits in the address are masked off.
/// For `/31` and `/32` every address in the block is returned. For wider prefixes
/// the network and broadcast addresses are excluded.
///
/// # Errors
/// Returns a user-facing message when the input is not `address/prefix`, the
/// address or prefix does not parse, the prefix exceeds 32, or the block spans
/// more than 1024 addresses.
fn parse_subnet_targets(raw_subnet: &str) -> Result<Vec<Ipv4Addr>, String> {
    const MAX_SCAN_HOSTS: u64 = 1024;

    let subnet = raw_subnet.trim();
    let Some((address_part, prefix_part)) = subnet.split_once('/') else {
        return Err(format!(
            "subnet '{subnet}' must be in CIDR form, e.g. 192.168.40.0/24"
        ));
    };
    let Ok(ip) = address_part.trim().parse::<Ipv4Addr>() else {
        return Err(format!("subnet '{subnet}' contains an invalid IPv4 address"));
    };
    let Ok(prefix) = prefix_part.trim().parse::<u8>() else {
        return Err(format!("subnet '{subnet}' contains an invalid CIDR prefix"));
    };
    if prefix > 32 {
        return Err(format!(
            "subnet '{subnet}' has prefix {prefix}, expected 0..=32"
        ));
    }

    let host_bits = 32 - u32::from(prefix);
    let host_span = 1u64 << host_bits;
    if host_span > MAX_SCAN_HOSTS {
        return Err(format!(
            "subnet '{subnet}' spans {host_span} addresses, limit is {MAX_SCAN_HOSTS}"
        ));
    }

    // A 32-bit shift overflows, so `/0` falls back to an all-zero mask.
    let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);
    let network = u32::from(ip) & mask;
    let broadcast = network | !mask;
    let (first, last) = match prefix {
        31 | 32 => (network, broadcast),
        _ => (network.saturating_add(1), broadcast.saturating_sub(1)),
    };
    if first > last {
        return Err(format!("subnet '{subnet}' has no scanable host addresses"));
    }

    Ok((first..=last).map(Ipv4Addr::from).collect())
}
|
|
|
|
fn scan_subnet_targets(targets: &[Ipv4Addr]) -> Vec<ApiDiscoveryResult> {
|
|
if targets.is_empty() {
|
|
return Vec::new();
|
|
}
|
|
|
|
let worker_count = usize::min(32, targets.len().max(1));
|
|
let (job_sender, job_receiver) = mpsc::channel::<Ipv4Addr>();
|
|
let job_receiver = Arc::new(Mutex::new(job_receiver));
|
|
let (result_sender, result_receiver) = mpsc::channel::<ApiDiscoveryResult>();
|
|
let mut handles = Vec::with_capacity(worker_count);
|
|
|
|
for _ in 0..worker_count {
|
|
let receiver = Arc::clone(&job_receiver);
|
|
let sender = result_sender.clone();
|
|
handles.push(thread::spawn(move || loop {
|
|
let next_job = {
|
|
let guard = receiver.lock();
|
|
match guard {
|
|
Ok(receiver) => receiver.recv().ok(),
|
|
Err(_) => None,
|
|
}
|
|
};
|
|
let Some(ip) = next_job else {
|
|
break;
|
|
};
|
|
let _ = sender.send(probe_ip(ip));
|
|
}));
|
|
}
|
|
drop(result_sender);
|
|
|
|
for ip in targets {
|
|
let _ = job_sender.send(*ip);
|
|
}
|
|
drop(job_sender);
|
|
|
|
let mut results = Vec::with_capacity(targets.len());
|
|
for _ in 0..targets.len() {
|
|
if let Ok(result) = result_receiver.recv() {
|
|
results.push(result);
|
|
}
|
|
}
|
|
|
|
for handle in handles {
|
|
let _ = handle.join();
|
|
}
|
|
results
|
|
}
|
|
|
|
fn probe_ip(ip: Ipv4Addr) -> ApiDiscoveryResult {
|
|
let mut reachable = false;
|
|
let mut detected_type = ApiDiscoveredNodeType::Unknown;
|
|
let mut hostname = None;
|
|
|
|
if let Some(info_probe) = probe_http_endpoint(ip, 80, "/json/info") {
|
|
reachable = true;
|
|
detected_type = detect_node_type(&info_probe.body, detected_type);
|
|
hostname = extract_probe_hostname(&info_probe);
|
|
} else if can_connect(ip, 80) {
|
|
reachable = true;
|
|
}
|
|
|
|
if !reachable && can_connect(ip, 81) {
|
|
reachable = true;
|
|
}
|
|
|
|
if detected_type == ApiDiscoveredNodeType::Unknown {
|
|
if let Some(node_probe) = probe_http_endpoint(ip, 80, "/api/v1/node/info") {
|
|
reachable = true;
|
|
detected_type = detect_node_type(&node_probe.body, detected_type);
|
|
if hostname.is_none() {
|
|
hostname = extract_probe_hostname(&node_probe);
|
|
}
|
|
} else if let Some(state_probe) = probe_http_endpoint(ip, 80, "/api/v1/state") {
|
|
reachable = true;
|
|
detected_type = detect_node_type(&state_probe.body, detected_type);
|
|
if hostname.is_none() {
|
|
hostname = extract_probe_hostname(&state_probe);
|
|
}
|
|
}
|
|
}
|
|
|
|
ApiDiscoveryResult {
|
|
ip: ip.to_string(),
|
|
reachable,
|
|
detected_type,
|
|
hostname,
|
|
}
|
|
}
|
|
|
|
/// Reports whether a TCP connection to `ip:port` opens within 120 ms.
///
/// The connection, if any, is dropped (closed) immediately.
fn can_connect(ip: Ipv4Addr, port: u16) -> bool {
    const CONNECT_TIMEOUT: Duration = Duration::from_millis(120);
    let target = SocketAddr::from((ip, port));
    match TcpStream::connect_timeout(&target, CONNECT_TIMEOUT) {
        Ok(_stream) => true,
        Err(_) => false,
    }
}
|
|
|
|
/// Parsed response of a discovery HTTP probe.
#[derive(Debug)]
struct HttpProbe {
    /// Response headers keyed by lower-cased name, with trimmed values.
    headers: HashMap<String, String>,
    /// Response body, decoded as UTF-8 with invalid sequences replaced.
    body: String,
}
|
|
|
|
fn probe_http_endpoint(ip: Ipv4Addr, port: u16, path: &str) -> Option<HttpProbe> {
|
|
let address = SocketAddr::new(IpAddr::V4(ip), port);
|
|
let mut stream = TcpStream::connect_timeout(&address, Duration::from_millis(120)).ok()?;
|
|
let _ = stream.set_read_timeout(Some(Duration::from_millis(180)));
|
|
let _ = stream.set_write_timeout(Some(Duration::from_millis(120)));
|
|
let request = format!(
|
|
"GET {path} HTTP/1.1\r\nHost: {ip}\r\nConnection: close\r\nAccept: application/json\r\n\r\n"
|
|
);
|
|
stream.write_all(request.as_bytes()).ok()?;
|
|
|
|
let mut raw = Vec::new();
|
|
stream.read_to_end(&mut raw).ok()?;
|
|
let header_end = find_header_end(&raw)?;
|
|
let header_text = String::from_utf8_lossy(&raw[..header_end]);
|
|
let headers = header_text
|
|
.lines()
|
|
.skip(1)
|
|
.filter_map(|line| line.split_once(':'))
|
|
.map(|(key, value)| (key.trim().to_ascii_lowercase(), value.trim().to_string()))
|
|
.collect::<HashMap<_, _>>();
|
|
let body = String::from_utf8_lossy(raw.get(header_end + 4..).unwrap_or_default()).to_string();
|
|
Some(HttpProbe { headers, body })
|
|
}
|
|
|
|
fn detect_node_type(body: &str, fallback: ApiDiscoveredNodeType) -> ApiDiscoveredNodeType {
|
|
let lowered = body.to_ascii_lowercase();
|
|
if lowered.contains("\"wled\"")
|
|
|| lowered.contains("\"brand\":\"wled\"")
|
|
|| lowered.contains("\"product\":\"wled\"")
|
|
{
|
|
return ApiDiscoveredNodeType::Wled;
|
|
}
|
|
if lowered.contains("\"native_node\"")
|
|
|| lowered.contains("\"node_kind\":\"native\"")
|
|
|| lowered.contains("\"infinity_node\"")
|
|
{
|
|
return ApiDiscoveredNodeType::NativeNode;
|
|
}
|
|
fallback
|
|
}
|
|
|
|
fn extract_probe_hostname(probe: &HttpProbe) -> Option<String> {
|
|
if let Ok(json) = serde_json::from_str::<Value>(&probe.body) {
|
|
let name = json
|
|
.get("name")
|
|
.and_then(Value::as_str)
|
|
.or_else(|| {
|
|
json.get("info")
|
|
.and_then(|value| value.get("name"))
|
|
.and_then(Value::as_str)
|
|
})
|
|
.or_else(|| json.get("mdns").and_then(Value::as_str));
|
|
if let Some(name) = name {
|
|
let trimmed = name.trim();
|
|
if !trimmed.is_empty() {
|
|
return Some(trimmed.to_string());
|
|
}
|
|
}
|
|
}
|
|
|
|
probe.headers.get("server").and_then(|value| {
|
|
let trimmed = value.trim();
|
|
if trimmed.is_empty() {
|
|
None
|
|
} else {
|
|
Some(trimmed.to_string())
|
|
}
|
|
})
|
|
}
|
|
|
|
fn send_stream_message(
|
|
stream: &mut TcpStream,
|
|
sequence: u64,
|
|
generated_at_millis: u64,
|
|
message: ApiStreamMessage,
|
|
) -> io::Result<()> {
|
|
let payload = serde_json::to_string(&ApiStreamEnvelope {
|
|
api_version: API_VERSION,
|
|
sequence,
|
|
generated_at_millis,
|
|
message,
|
|
})
|
|
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
|
write_text_frame(stream, &payload)
|
|
}
|
|
|
|
fn respond_json<T: serde::Serialize>(
|
|
stream: &mut TcpStream,
|
|
status: u16,
|
|
body: &T,
|
|
) -> io::Result<()> {
|
|
let payload = serde_json::to_string_pretty(body)
|
|
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
|
respond_text(stream, status, "application/json; charset=utf-8", &payload)
|
|
}
|
|
|
|
fn respond_error(
|
|
stream: &mut TcpStream,
|
|
status: u16,
|
|
code: impl Into<String>,
|
|
message: impl Into<String>,
|
|
) -> io::Result<()> {
|
|
respond_json(stream, status, &ApiErrorResponse::new(code, message))
|
|
}
|
|
|
|
/// Writes a complete HTTP/1.1 response (`Content-Length` framed, `Connection: close`)
/// in a single write.
///
/// The status line uses the standard reason phrase for `status`. Codes without a
/// listed phrase get an empty reason phrase, which RFC 9112 permits, rather than a
/// misleading `OK`.
fn respond_text(
    stream: &mut TcpStream,
    status: u16,
    content_type: &str,
    body: &str,
) -> io::Result<()> {
    let reason = match status {
        101 => "Switching Protocols",
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        408 => "Request Timeout",
        409 => "Conflict",
        413 => "Content Too Large",
        415 => "Unsupported Media Type",
        422 => "Unprocessable Content",
        429 => "Too Many Requests",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        _ => "",
    };
    // `str::len` is the UTF-8 byte length, which is what Content-Length counts.
    let response = format!(
        "HTTP/1.1 {status} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
        body.len()
    );
    stream.write_all(response.as_bytes())
}
|
|
|
|
/// A single HTTP request as read by `read_request`.
#[derive(Debug)]
struct HttpRequest {
    /// Request method exactly as sent, e.g. `GET`.
    method: String,
    /// Request target with any `?query` part removed.
    path: String,
    /// Header values keyed by lower-cased header name, with trimmed values.
    headers: HashMap<String, String>,
    /// Raw bytes that followed the blank line ending the header block.
    body: Vec<u8>,
}
|
|
|
|
impl HttpRequest {
|
|
fn header(&self, key: &str) -> Option<&str> {
|
|
self.headers
|
|
.get(&key.to_ascii_lowercase())
|
|
.map(|value| value.as_str())
|
|
}
|
|
|
|
fn is_websocket(&self) -> bool {
|
|
self.header("upgrade")
|
|
.map(|value| value.eq_ignore_ascii_case("websocket"))
|
|
.unwrap_or(false)
|
|
}
|
|
}
|
|
|
|
fn read_request(stream: &mut TcpStream) -> io::Result<HttpRequest> {
|
|
let mut buffer = Vec::new();
|
|
let mut temp = [0u8; 4096];
|
|
let mut header_end = None;
|
|
let mut expected_len = None;
|
|
|
|
loop {
|
|
let read = stream.read(&mut temp)?;
|
|
if read == 0 {
|
|
break;
|
|
}
|
|
buffer.extend_from_slice(&temp[..read]);
|
|
if header_end.is_none() {
|
|
header_end = find_header_end(&buffer);
|
|
if let Some(end) = header_end {
|
|
let header_text = String::from_utf8_lossy(&buffer[..end]);
|
|
expected_len = parse_content_length(&header_text);
|
|
if expected_len == Some(0) || expected_len.is_none() {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
if let (Some(end), Some(content_len)) = (header_end, expected_len) {
|
|
if buffer.len() >= end + 4 + content_len {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
let header_end = header_end
|
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing header end"))?;
|
|
let header_text = String::from_utf8_lossy(&buffer[..header_end]);
|
|
let mut lines = header_text.lines();
|
|
let request_line = lines
|
|
.next()
|
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing request line"))?;
|
|
let mut request_parts = request_line.split_whitespace();
|
|
let method = request_parts
|
|
.next()
|
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing method"))?
|
|
.to_string();
|
|
let path = request_parts
|
|
.next()
|
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing path"))?
|
|
.split('?')
|
|
.next()
|
|
.unwrap_or("/")
|
|
.to_string();
|
|
let mut headers = HashMap::new();
|
|
for line in lines {
|
|
if let Some((key, value)) = line.split_once(':') {
|
|
headers.insert(key.trim().to_ascii_lowercase(), value.trim().to_string());
|
|
}
|
|
}
|
|
let body_start = header_end + 4;
|
|
let body = buffer.get(body_start..).unwrap_or_default().to_vec();
|
|
Ok(HttpRequest {
|
|
method,
|
|
path,
|
|
headers,
|
|
body,
|
|
})
|
|
}
|
|
|
|
/// Returns the first parsable `Content-Length` value in an HTTP header block.
///
/// Header names match case-insensitively. Lines whose value does not parse as
/// `usize` are skipped rather than ending the search.
fn parse_content_length(header_text: &str) -> Option<usize> {
    for line in header_text.lines() {
        let Some((name, raw_value)) = line.split_once(':') else {
            continue;
        };
        if !name.trim().eq_ignore_ascii_case("content-length") {
            continue;
        }
        if let Ok(length) = raw_value.trim().parse::<usize>() {
            return Some(length);
        }
    }
    None
}
|
|
|
|
/// Returns the byte offset of the first `\r\n\r\n` in `buffer` (the blank line that
/// ends an HTTP header block), or `None` if it has not arrived yet.
fn find_header_end(buffer: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    (0..buffer.len()).find(|&offset| buffer[offset..].starts_with(TERMINATOR))
}
|