From ad7433cf2de2f4776f3b9e0414299439cc6056a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=98=A4=EB=B3=91=EC=A4=80?= Date: Thu, 2 May 2024 01:50:38 +0900 Subject: [PATCH] improve: tcp read/write with function, not macro --- src/bin/transistor-client.rs | 2 - src/client.rs | 84 +++++++++++++++++++++++++----------- src/display.rs | 8 ++-- src/server.rs | 53 +++++++++++++++++------ src/utils.rs | 72 +++++++++++-------------------- 5 files changed, 128 insertions(+), 91 deletions(-) diff --git a/src/bin/transistor-client.rs b/src/bin/transistor-client.rs index 8c0c8f0..b6eeb0a 100644 --- a/src/bin/transistor-client.rs +++ b/src/bin/transistor-client.rs @@ -17,11 +17,9 @@ fn main() -> Result<(), Error> { let server = &args[1]; println!("[INF] transistor client startup! server: {}", server); - print_displays(); let mut client = Client::new(server)?; - client.start()?; loop {} diff --git a/src/client.rs b/src/client.rs index f0ef9f2..2bc1215 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; -use std::io::{stdout, Error, ErrorKind::*, Read, Write}; -use std::{fs, net::TcpStream}; -use std::{mem, u32}; +use std::fs; +use std::io::{stdout, Error, ErrorKind::*, Write}; +use std::mem; +use std::net::TcpStream; use bincode::deserialize; use display_info::DisplayInfo; @@ -45,33 +46,64 @@ impl Client { pub fn start(&mut self) -> Result<(), Error> { // transmit cid to server - tcp_stream_write!(self.tcp, self.cid); + if let Err(e) = tcp_write(&mut self.tcp, self.cid) { + return Err(Error::new( + ConnectionRefused, + format!("handshake failed: {:?}", e), + )); + }; /* receive display counts; 0 is unauthorized */ let mut buffer = vec![0u8; mem::size_of::()]; - tcp_stream_read!(self.tcp, buffer); + + if let Err(e) = tcp_read(&mut self.tcp, &mut buffer) { + return Err(Error::new( + ConnectionRefused, + format!("handshake failed: {:?}", e), + )); + }; + let disp_cnt: u32 = deserialize(&buffer).unwrap(); if disp_cnt < 1 { - return Err(Error::new(ConnectionRefused, "[ERR] authorization failed")); + return Err(Error::new(ConnectionRefused, "authorization failed")); } // receive server's current display configurations - tcp_stream_read_resize!(self.tcp, buffer); + if let Err(e) = tcp_read(&mut self.tcp, &mut buffer) { + return Err(Error::new( + ConnectionRefused, + format!("handshake failed: {:?}", e), + )); + }; + let server_disp_map: HashMap = deserialize(&buffer).unwrap(); let server_disp: Vec = server_disp_map.values().cloned().collect(); - // configure our displays' attach position and transmit to server + /* configure our displays' attach position and transmit to server */ self.set_display_position(server_disp); - tcp_stream_write!(self.tcp, self.displays); + + if let Err(e) = tcp_write(&mut self.tcp, self.displays.clone()) { + return Err(Error::new( + ConnectionRefused, + format!("handshake failed: {:?}", e), + )); + }; /* wait server ack */ - tcp_stream_read!(self.tcp, buffer); + if let Err(e) = tcp_read(&mut self.tcp, &mut buffer) { + return Err(Error::new( + ConnectionRefused, + format!("handshake failed: {:?}", e), + )); + }; if let HandshakeStatus::HandshakeErr = deserialize(&buffer).unwrap() { return Err(Error::new(ConnectionRefused, "[ERR] request rejected")); }; + println!("[INF] connected!"); + Ok(()) } @@ -132,6 +164,22 @@ impl Client { } } +fn load_or_generate_cid() -> Result { + let cid_file = config_dir!("client").join("cid.txt"); + + if cid_file.exists() { + let txt = fs::read_to_string(cid_file)?; + Ok(txt.parse().expect("[ERR] failed to load cid")) + } else { + let cid: Cid = rand::random(); + + let mut file = fs::File::create(&cid_file)?; + file.write_all(cid.to_string().as_bytes())?; + + Ok(cid) + } +} + fn prompt_display_position(displays: &mut Vec, server_conf: Vec) { println!("\n########## display position setup ##########"); println!("[INF] current server displays:"); @@ -213,20 +261,6 @@ fn prompt_display_position(displays: &mut Vec, server_conf: Vec Result { - let cid_file = config_dir!("client").join("cid.txt"); - - if cid_file.exists() { - let txt = fs::read_to_string(cid_file)?; - Ok(txt.parse().expect("[ERR] failed to load cid")) - } else { - let cid: Cid = rand::random(); - let mut file = fs::File::create(&cid_file)?; - file.write_all(cid.to_string().as_bytes())?; - - Ok(cid) - } + // TODO: ask write config to file } diff --git a/src/display.rs b/src/display.rs index 874ca0f..3e9b9eb 100644 --- a/src/display.rs +++ b/src/display.rs @@ -140,7 +140,7 @@ pub fn create_warpzones(a: &mut Vec, b: &mut Vec, eq: bool) -> if disp.is_overlap(target.clone()) { return Err(Error::new( InvalidInput, - "[ERR] two displays are overlapping", + "displays are overlapping", )); } @@ -175,15 +175,17 @@ pub fn create_warpzones_hashmap( let mut new = Vec::new(); - // check overlap first + // check overlap and isolated displays first for disp in a.iter() { for target in b.iter() { if disp.is_overlap(target.clone()) { return Err(Error::new( InvalidInput, - "[ERR] two displays are overlapping", + "displays are overlapping", )); } + + // TODO: verify isolated display } } diff --git a/src/server.rs b/src/server.rs index 21701f2..75e5778 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::fs; -use std::io::{Error, ErrorKind::*, Read, Write}; +use std::io::{Error, ErrorKind::*}; use std::mem; use std::net::TcpListener; use std::path::PathBuf; @@ -86,43 +86,68 @@ fn handle_client( mut displays: Arc>>, disp_ids: Arc>, authorized: Vec, -) -> Result<(), Error> { +) { let tcp = TcpListener::bind(("0.0.0.0", PORT)).expect("[ERR] TCP binding failed"); for mut stream in tcp.incoming().filter_map(Result::ok) { + let ip = stream.peer_addr().unwrap(); + // read cid from remote client let mut buffer = vec![0u8; mem::size_of::()]; - tcp_stream_read!(stream, buffer); + + if let Err(e) = tcp_read(&mut stream, &mut buffer) { + eprintln!("[ERR] client {:?} handshake failed: {:?}", ip, e); + continue; + }; + let cid = deserialize(&buffer).unwrap(); - // reject not known client + // reject unknown client if !authorized.contains(&cid) { - tcp_stream_write!(stream, 0); + if let Err(e) = tcp_write(&mut stream, 0) { + eprintln!("[ERR] client {:?} handshake failed: {:?}", ip, e); + continue; + }; } // transmit display counts to client - tcp_stream_write!(stream, displays.read().unwrap().len() as u32); + if let Err(e) = tcp_write(&mut stream, displays.read().unwrap().len() as u32) { + eprintln!("[ERR] client {:?} handshake failed: {:?}", ip, e); + continue; + }; // transmit current displays { let disp = displays.read().unwrap(); - tcp_stream_write!(stream, *disp); + + if let Err(e) = tcp_write(&mut stream, disp.clone()) { + eprintln!("[ERR] client {:?} handshake failed: {:?}", ip, e); + continue; + }; } // receive display attach request - tcp_stream_read_resize!(stream, buffer); + if let Err(e) = tcp_read(&mut stream, &mut buffer) { + eprintln!("[ERR] client {:?} handshake failed: {:?}", ip, e); + continue; + }; + let mut client_disp: Vec = deserialize(&buffer).unwrap(); // update warpzones for new displays let new = match create_warpzones_hashmap(&mut displays, &mut client_disp) { Ok(new) => new, - Err(_) => { - return Err(Error::new(InvalidData, "[ERR] system display init failed")); + Err(e) => { + eprintln!("[ERR] invalid request from client {:?} : {:?}", ip, e); + continue; } }; - + // transmit ack - tcp_stream_write!(stream, HandshakeStatus::HandshakeOk); + if let Err(e) = tcp_write(&mut stream, HandshakeStatus::HandshakeOk as i32) { + eprintln!("[ERR] client {:?} handshake failed: {:?}", ip, e); + continue; + }; // add accepted client and display list let client = Client { @@ -133,9 +158,9 @@ fn handle_client( clients.write().unwrap().insert(cid, client); disp_ids.write().unwrap().client.extend(new); - } - Ok(()) + println!("[INF] client {:?} connected!", ip); + } } fn authorized_clients(file: PathBuf) -> Result, Error> { diff --git a/src/utils.rs b/src/utils.rs index bb812aa..591219b 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,5 @@ -use std::io::stdin; -use std::io::{Error, ErrorKind::*}; +use std::io::{stdin, Error, ErrorKind::*, Read, Write}; +use std::net::TcpStream; use display_info::DisplayInfo; use serde::{Deserialize, Serialize}; @@ -48,6 +48,29 @@ pub fn stdin_char() -> Result { } } +pub fn tcp_read(stream: &mut TcpStream, buffer: &mut Vec) -> Result { + let mut size = [0u8; 4]; + stream.read_exact(&mut size)?; + + let len = u32::from_be_bytes(size) as usize; + buffer.resize(len, 0); + + stream.read_exact(buffer)?; + + Ok(len) +} + +pub fn tcp_write(stream: &mut TcpStream, data: T) -> Result { + let encoded = bincode::serialize(&data).unwrap(); + let len = encoded.len(); + let size = (len as u32).to_be_bytes(); // force 4 byte data length + + stream.write_all(&size)?; + stream.write_all(&encoded)?; + + Ok(len) +} + #[macro_export] macro_rules! config_dir { ($subpath: expr) => {{ @@ -61,48 +84,3 @@ macro_rules! config_dir { }}; } -#[macro_export] -macro_rules! tcp_stream_read { - ($stream:expr, $buffer:expr) => {{ - let mut size = [0u8; 4]; - $stream.read_exact(&mut size)?; - - let len = u32::from_be_bytes(size) as usize; - $stream.read_exact(&mut $buffer[..len])?; - - len - }}; -} - -#[macro_export] -macro_rules! tcp_stream_read_resize { - ($stream:expr, $buffer:expr) => {{ - let mut size = [0u8; 4]; - $stream.read_exact(&mut size)?; - - let len = u32::from_be_bytes(size) as usize; - $buffer.resize(len, 0); - $stream.read_exact(&mut $buffer)?; - - len - }}; -} - -#[macro_export] -macro_rules! tcp_stream_write { - ($stream:expr, $data:expr) => { - let encoded = bincode::serialize(&$data).unwrap(); - - // force 4 byte data length - let len = encoded.len() as u32; - let size = len.to_be_bytes(); - - if let Err(e) = $stream.write_all(&size) { - eprintln!("[ERR] TCP stream write failed: {}", e); - } - - if let Err(e) = $stream.write_all(&encoded) { - eprintln!("[ERR] TCP stream write failed: {}", e); - } - }; -}