Skip to content

Commit

Permalink
improve: tcp read/write with function, not macro
Browse files Browse the repository at this point in the history
  • Loading branch information
luftaquila committed May 1, 2024
1 parent d1afd51 commit af1aeff
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 90 deletions.
2 changes: 0 additions & 2 deletions src/bin/transistor-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ fn main() -> Result<(), Error> {
let server = &args[1];

println!("[INF] transistor client startup! server: {}", server);

print_displays();

let mut client = Client::new(server)?;

client.start()?;

loop {}
Expand Down
84 changes: 59 additions & 25 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::collections::HashMap;
use std::io::{stdout, Error, ErrorKind::*, Read, Write};
use std::{fs, net::TcpStream};
use std::{mem, u32};
use std::fs;
use std::io::{stdout, Error, ErrorKind::*, Write};
use std::mem;
use std::net::TcpStream;

use bincode::deserialize;
use display_info::DisplayInfo;
Expand Down Expand Up @@ -45,33 +46,64 @@ impl Client {

pub fn start(&mut self) -> Result<(), Error> {
// transmit cid to server
tcp_stream_write!(self.tcp, self.cid);
if let Err(e) = tcp_write(&mut self.tcp, self.cid) {
return Err(Error::new(
ConnectionRefused,
format!("handshake failed: {:?}", e),
));
};

/* receive display counts; 0 is unauthorized */
let mut buffer = vec![0u8; mem::size_of::<u32>()];
tcp_stream_read!(self.tcp, buffer);

if let Err(e) = tcp_read(&mut self.tcp, &mut buffer) {
return Err(Error::new(
ConnectionRefused,
format!("handshake failed: {:?}", e),
));
};

let disp_cnt: u32 = deserialize(&buffer).unwrap();

if disp_cnt < 1 {
return Err(Error::new(ConnectionRefused, "[ERR] authorization failed"));
return Err(Error::new(ConnectionRefused, "authorization failed"));
}

// receive server's current display configurations
tcp_stream_read_resize!(self.tcp, buffer);
if let Err(e) = tcp_read(&mut self.tcp, &mut buffer) {
return Err(Error::new(
ConnectionRefused,
format!("handshake failed: {:?}", e),
));
};

let server_disp_map: HashMap<Did, Display> = deserialize(&buffer).unwrap();
let server_disp: Vec<Display> = server_disp_map.values().cloned().collect();

// configure our displays' attach position and transmit to server
/* configure our displays' attach position and transmit to server */
self.set_display_position(server_disp);
tcp_stream_write!(self.tcp, self.displays);

if let Err(e) = tcp_write(&mut self.tcp, self.displays.clone()) {
return Err(Error::new(
ConnectionRefused,
format!("handshake failed: {:?}", e),
));
};

/* wait server ack */
tcp_stream_read!(self.tcp, buffer);
if let Err(e) = tcp_read(&mut self.tcp, &mut buffer) {
return Err(Error::new(
ConnectionRefused,
format!("handshake failed: {:?}", e),
));
};

if let HandshakeStatus::HandshakeErr = deserialize(&buffer).unwrap() {
return Err(Error::new(ConnectionRefused, "[ERR] request rejected"));
};

println!("[INF] connected!");

Ok(())
}

Expand Down Expand Up @@ -132,6 +164,22 @@ impl Client {
}
}

fn load_or_generate_cid() -> Result<Cid, Error> {
let cid_file = config_dir!("client").join("cid.txt");

if cid_file.exists() {
let txt = fs::read_to_string(cid_file)?;
Ok(txt.parse().expect("[ERR] failed to load cid"))
} else {
let cid: Cid = rand::random();

let mut file = fs::File::create(&cid_file)?;
file.write_all(cid.to_string().as_bytes())?;

Ok(cid)
}
}

fn prompt_display_position(displays: &mut Vec<Display>, server_conf: Vec<Display>) {
println!("\n########## display position setup ##########");
println!("[INF] current server displays:");
Expand Down Expand Up @@ -213,20 +261,6 @@ fn prompt_display_position(displays: &mut Vec<Display>, server_conf: Vec<Display
}
}
}
}

fn load_or_generate_cid() -> Result<Cid, Error> {
let cid_file = config_dir!("client").join("cid.txt");

if cid_file.exists() {
let txt = fs::read_to_string(cid_file)?;
Ok(txt.parse().expect("[ERR] failed to load cid"))
} else {
let cid: Cid = rand::random();

let mut file = fs::File::create(&cid_file)?;
file.write_all(cid.to_string().as_bytes())?;

Ok(cid)
}
// TODO: ask write config to file
}
6 changes: 4 additions & 2 deletions src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ pub fn create_warpzones(a: &mut Vec<Display>, b: &mut Vec<Display>, eq: bool) ->
if disp.is_overlap(target.clone()) {
return Err(Error::new(
InvalidInput,
"[ERR] two displays are overlapping",
"displays are overlapping",
));
}

Expand Down Expand Up @@ -181,12 +181,14 @@ pub fn create_warpzones_hashmap(
if disp.is_overlap(target.clone()) {
return Err(Error::new(
InvalidInput,
"[ERR] two displays are overlapping",
"displays are overlapping",
));
}
}
}

// ! TODO: verify isolated display before add

// add warpzones
for disp in a.iter() {
for target in b.iter() {
Expand Down
53 changes: 39 additions & 14 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashMap;
use std::fs;
use std::io::{Error, ErrorKind::*, Read, Write};
use std::io::{Error, ErrorKind::*};
use std::mem;
use std::net::TcpListener;
use std::path::PathBuf;
Expand Down Expand Up @@ -86,43 +86,68 @@ fn handle_client(
mut displays: Arc<RwLock<HashMap<Did, Display>>>,
disp_ids: Arc<RwLock<AssignedDisplays>>,
authorized: Vec<Cid>,
) -> Result<(), Error> {
) {
let tcp = TcpListener::bind(("0.0.0.0", PORT)).expect("[ERR] TCP binding failed");

for mut stream in tcp.incoming().filter_map(Result::ok) {
let ip = stream.peer_addr().unwrap();

// read cid from remote client
let mut buffer = vec![0u8; mem::size_of::<Cid>()];
tcp_stream_read!(stream, buffer);

if let Err(e) = tcp_read(&mut stream, &mut buffer) {
eprintln!("[ERR] client {:?} handshake failed: {:?}", ip, e);
continue;
};

let cid = deserialize(&buffer).unwrap();

// reject not known client
// reject unknown client
if !authorized.contains(&cid) {
tcp_stream_write!(stream, 0);
if let Err(e) = tcp_write(&mut stream, 0) {
eprintln!("[ERR] client {:?} handshake failed: {:?}", ip, e);
continue;
};
}

// transmit display counts to client
tcp_stream_write!(stream, displays.read().unwrap().len() as u32);
if let Err(e) = tcp_write(&mut stream, displays.read().unwrap().len() as u32) {
eprintln!("[ERR] client {:?} handshake failed: {:?}", ip, e);
continue;
};

// transmit current displays
{
let disp = displays.read().unwrap();
tcp_stream_write!(stream, *disp);

if let Err(e) = tcp_write(&mut stream, disp.clone()) {
eprintln!("[ERR] client {:?} handshake failed: {:?}", ip, e);
continue;
};
}

// receive display attach request
tcp_stream_read_resize!(stream, buffer);
if let Err(e) = tcp_read(&mut stream, &mut buffer) {
eprintln!("[ERR] client {:?} handshake failed: {:?}", ip, e);
continue;
};

let mut client_disp: Vec<Display> = deserialize(&buffer).unwrap();

// update warpzones for new displays
let new = match create_warpzones_hashmap(&mut displays, &mut client_disp) {
Ok(new) => new,
Err(_) => {
return Err(Error::new(InvalidData, "[ERR] system display init failed"));
Err(e) => {
eprintln!("[ERR] invalid request from client {:?} : {:?}", ip, e);
continue;
}
};

// transmit ack
tcp_stream_write!(stream, HandshakeStatus::HandshakeOk);
if let Err(e) = tcp_write(&mut stream, HandshakeStatus::HandshakeOk as i32) {
eprintln!("[ERR] client {:?} handshake failed: {:?}", ip, e);
continue;
};

// add accepted client and display list
let client = Client {
Expand All @@ -133,9 +158,9 @@ fn handle_client(

clients.write().unwrap().insert(cid, client);
disp_ids.write().unwrap().client.extend(new);
}

Ok(())
println!("[INF] client {:?} connected!", ip);
}
}

fn authorized_clients(file: PathBuf) -> Result<Vec<Cid>, Error> {
Expand Down
72 changes: 25 additions & 47 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::io::stdin;
use std::io::{Error, ErrorKind::*};
use std::io::{stdin, Error, ErrorKind::*, Read, Write};
use std::net::TcpStream;

use display_info::DisplayInfo;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -48,6 +48,29 @@ pub fn stdin_char() -> Result<char, Error> {
}
}

pub fn tcp_read(stream: &mut TcpStream, buffer: &mut Vec<u8>) -> Result<usize, Error> {
let mut size = [0u8; 4];
stream.read_exact(&mut size)?;

let len = u32::from_be_bytes(size) as usize;
buffer.resize(len, 0);

stream.read_exact(buffer)?;

Ok(len)
}

pub fn tcp_write<T: Serialize>(stream: &mut TcpStream, data: T) -> Result<usize, Error> {
let encoded = bincode::serialize(&data).unwrap();
let len = encoded.len();
let size = (len as u32).to_be_bytes(); // force 4 byte data length

stream.write_all(&size)?;
stream.write_all(&encoded)?;

Ok(len)
}

#[macro_export]
macro_rules! config_dir {
($subpath: expr) => {{
Expand All @@ -61,48 +84,3 @@ macro_rules! config_dir {
}};
}

#[macro_export]
macro_rules! tcp_stream_read {
($stream:expr, $buffer:expr) => {{
let mut size = [0u8; 4];
$stream.read_exact(&mut size)?;

let len = u32::from_be_bytes(size) as usize;
$stream.read_exact(&mut $buffer[..len])?;

len
}};
}

#[macro_export]
macro_rules! tcp_stream_read_resize {
($stream:expr, $buffer:expr) => {{
let mut size = [0u8; 4];
$stream.read_exact(&mut size)?;

let len = u32::from_be_bytes(size) as usize;
$buffer.resize(len, 0);
$stream.read_exact(&mut $buffer)?;

len
}};
}

#[macro_export]
macro_rules! tcp_stream_write {
($stream:expr, $data:expr) => {
let encoded = bincode::serialize(&$data).unwrap();

// force 4 byte data length
let len = encoded.len() as u32;
let size = len.to_be_bytes();

if let Err(e) = $stream.write_all(&size) {
eprintln!("[ERR] TCP stream write failed: {}", e);
}

if let Err(e) = $stream.write_all(&encoded) {
eprintln!("[ERR] TCP stream write failed: {}", e);
}
};
}

0 comments on commit af1aeff

Please sign in to comment.