diff --git a/.all-contributorsrc b/.all-contributorsrc index 1db94a0b..66444bad 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -175,6 +175,150 @@ "contributions": [ "code" ] + }, + { + "login": "Nurrl", + "name": "Maya the bee", + "avatar_url": "https://avatars.githubusercontent.com/u/15341887?v=4", + "profile": "https://github.com/Nurrl", + "contributions": [ + "code" + ] + }, + { + "login": "mmirate", + "name": "Milo Mirate", + "avatar_url": "https://avatars.githubusercontent.com/u/992859?v=4", + "profile": "https://github.com/mmirate", + "contributions": [ + "code" + ] + }, + { + "login": "george-hopkins", + "name": "George Hopkins", + "avatar_url": "https://avatars.githubusercontent.com/u/552590?v=4", + "profile": "https://github.com/george-hopkins", + "contributions": [ + "code" + ] + }, + { + "login": "akeamc", + "name": "Åke Amcoff", + "avatar_url": "https://avatars.githubusercontent.com/u/17624114?v=4", + "profile": "https://amcoff.net/", + "contributions": [ + "code" + ] + }, + { + "login": "bho01", + "name": "Brendon Ho", + "avatar_url": "https://avatars.githubusercontent.com/u/12106620?v=4", + "profile": "http://brendonho.com", + "contributions": [ + "code" + ] + }, + { + "login": "samuela", + "name": "Samuel Ainsworth", + "avatar_url": "https://avatars.githubusercontent.com/u/226872?v=4", + "profile": "http://samlikes.pizza/", + "contributions": [ + "code" + ] + }, + { + "login": "sherlock-holo", + "name": "Sherlock Holo", + "avatar_url": "https://avatars.githubusercontent.com/u/10096425?v=4", + "profile": "https://github.com/Sherlock-Holo", + "contributions": [ + "code" + ] + }, + { + "login": "ricott1", + "name": "Alessandro Ricottone", + "avatar_url": "https://avatars.githubusercontent.com/u/16502243?v=4", + "profile": "https://github.com/ricott1", + "contributions": [ + "code" + ] + }, + { + "login": "T0b1-iOS", + "name": "T0b1-iOS", + "avatar_url": "https://avatars.githubusercontent.com/u/15174814?v=4", + "profile": "https://github.com/T0b1-iOS", + "contributions": [ + "code" + ] + }, + { + "login": "shoaibmerchant", + "name": "Shoaib Merchant", + "avatar_url": "https://avatars.githubusercontent.com/u/4598631?v=4", + "profile": "https://mecha.so", + "contributions": [ + "code" + ] + }, + { + "login": "gleason-m", + "name": "Michael Gleason", + "avatar_url": "https://avatars.githubusercontent.com/u/86493344?v=4", + "profile": "https://github.com/gleason-m", + "contributions": [ + "code" + ] + }, + { + "login": "elegaanz", + "name": "Ana Gelez", + "avatar_url": "https://avatars.githubusercontent.com/u/16254623?v=4", + "profile": "https://ana.gelez.xyz", + "contributions": [ + "code" + ] + }, + { + "login": "tomknig", + "name": "Tom König", + "avatar_url": "https://avatars.githubusercontent.com/u/3586316?v=4", + "profile": "https://github.com/tomknig", + "contributions": [ + "code" + ] + }, + { + "login": "Barre", + "name": "Pierre Barre", + "avatar_url": "https://avatars.githubusercontent.com/u/45085843?v=4", + "profile": "https://www.legaltile.com/", + "contributions": [ + "code" + ] + }, + { + "login": "spoutn1k", + "name": "Jean-Baptiste Skutnik", + "avatar_url": "https://avatars.githubusercontent.com/u/22240065?v=4", + "profile": "http://skutnik.page", + "contributions": [ + "code" + ] + }, + { + "login": "packetsource", + "name": "Adam Chappell", + "avatar_url": "https://avatars.githubusercontent.com/u/6276475?v=4", + "profile": "http://blog.packetsource.net/", + "contributions": [ + "code" + ] } ], "contributorsPerLine": 7, diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..c236b04c --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +github: eugeny +open_collective: tabby +ko_fi: eugeny diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 66ef82a8..3174f403 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -2,9 +2,9 @@ name: Rust on: push: - branches: [ master ] + branches: [ main ] pull_request: - branches: [ master ] + branches: [ main ] env: CARGO_TERM_COLOR: always @@ -27,6 +27,18 @@ jobs: with: package: russh + Formatting: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Install rustfmt + run: rustup component add rustfmt + + - name: rustfmt + run: cargo fmt --check + Clippy: runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index d4ed211c..8ba84e34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,26 @@ [workspace] -members = [ "russh-keys", "russh", "russh-config", "cryptovec"] +members = ["russh-keys", "russh", "russh-config", "cryptovec", "pageant"] [patch.crates-io] russh = { path = "russh" } russh-keys = { path = "russh-keys" } russh-cryptovec = { path = "cryptovec" } russh-config = { path = "russh-config" } + +[workspace.dependencies] +aes = "0.8" +async-trait = "0.1" +byteorder = "1.4" +digest = "0.10" +futures = "0.3" +hmac = "0.12" +log = "0.4" +openssl = { version = "0.10" } +rand = "0.8" +sha1 = { version = "0.10", features = ["oid"] } +sha2 = { version = "0.10", features = ["oid"] } +ssh-encoding = "0.2" +ssh-key = { version = "0.6", features = ["ed25519", "rsa", "encryption"] } +thiserror = "1.0" +tokio = { version = "1.17.0" } +tokio-stream = { version = "0.1", features = ["net", "sync"] } diff --git a/README.md b/README.md index 76346d74..44b5dd36 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,13 @@ # Russh + [![Rust](https://github.com/warp-tech/russh/actions/workflows/rust.yml/badge.svg)](https://github.com/warp-tech/russh/actions/workflows/rust.yml) -[![All Contributors](https://img.shields.io/badge/all_contributors-19-orange.svg?style=flat-square)](#contributors-) +[![All Contributors](https://img.shields.io/badge/all_contributors-35-orange.svg?style=flat-square)](#contributors-) Low-level Tokio SSH2 client and server implementation. +Examples: [simple client](russh/examples/client_exec_simple.rs), [interactive PTY client](russh/examples/client_exec_interactive.rs), [server](russh/examples/echoserver.rs), [SFTP client](russh/examples/sftp_client.rs), [SFTP server](russh/examples/sftp_server.rs). + This is a fork of [Thrussh](https://nest.pijul.com/pijul/thrussh) by Pierre-Étienne Meunier. > ✨ = added in Russh @@ -14,17 +17,26 @@ This is a fork of [Thrussh](https://nest.pijul.com/pijul/thrussh) by Pierre-Éti * `direct-tcpip` (local port forwarding) * `forward-tcpip` (remote port forwarding) ✨ * `direct-streamlocal` (local UNIX socket forwarding, client only) ✨ +* `forward-streamlocal` (remote UNIX socket forwarding) ✨ * Ciphers: * `chacha20-poly1305@openssh.com` * `aes256-gcm@openssh.com` ✨ * `aes256-ctr` ✨ * `aes192-ctr` ✨ * `aes128-ctr` ✨ + * `aes256-cbc` ✨ + * `aes192-cbc` ✨ + * `aes128-cbc` ✨ + * `3des-cbc` ✨ * Key exchanges: * `curve25519-sha256@libssh.org` * `diffie-hellman-group1-sha1` ✨ * `diffie-hellman-group14-sha1` ✨ * `diffie-hellman-group14-sha256` ✨ + * `diffie-hellman-group16-sha512` ✨ + * `ecdh-sha2-nistp256` ✨ + * `ecdh-sha2-nistp384` ✨ + * `ecdh-sha2-nistp521` ✨ * MACs: * `hmac-sha1` ✨ * `hmac-sha2-256` ✨ @@ -32,15 +44,25 @@ This is a fork of [Thrussh](https://nest.pijul.com/pijul/thrussh) by Pierre-Éti * `hmac-sha1-etm@openssh.com` ✨ * `hmac-sha2-256-etm@openssh.com` ✨ * `hmac-sha2-512-etm@openssh.com` ✨ -* Host keys: +* Host keys and public key auth: * `ssh-ed25519` * `rsa-sha2-256` * `rsa-sha2-512` * `ssh-rsa` ✨ + * `ecdsa-sha2-nistp256` ✨ + * `ecdsa-sha2-nistp384` ✨ + * `ecdsa-sha2-nistp521` ✨ +* Authentication methods: + * `password` + * `publickey` + * `keyboard-interactive` + * `none` + * OpenSSH certificates (client only ✨) * Dependency updates * OpenSSH keepalive request handling ✨ * OpenSSH agent forwarding channels ✨ * OpenSSH `server-sig-algs` extension ✨ +* `openssl` dependency is optional ✨ ## Safety @@ -60,7 +82,7 @@ This is a fork of [Thrussh](https://nest.pijul.com/pijul/thrussh) by Pierre-Éti ## Ecosystem -* [russh-sftp](https://crates.io/crates/russh-sftp) - server-side SFTP subsystem support for `russh` - see `russh/examples/sftp_server.rs`. +* [russh-sftp](https://crates.io/crates/russh-sftp) - server-side and client-side SFTP subsystem support for `russh` - see `russh/examples/sftp_server.rs` or `russh/examples/sftp_client.rs`. * [async-ssh2-tokio](https://crates.io/crates/async-ssh2-tokio) - simple high-level API for running commands over SSH. ## Contributors ✨ @@ -96,6 +118,26 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d Saksham Mittal
Saksham Mittal

💻 Lucas Kent
Lucas Kent

💻 Raphael Druon
Raphael Druon

💻 + Maya the bee
Maya the bee

💻 + Milo Mirate
Milo Mirate

💻 + + + George Hopkins
George Hopkins

💻 + Åke Amcoff
Åke Amcoff

💻 + Brendon Ho
Brendon Ho

💻 + Samuel Ainsworth
Samuel Ainsworth

💻 + Sherlock Holo
Sherlock Holo

💻 + Alessandro Ricottone
Alessandro Ricottone

💻 + T0b1-iOS
T0b1-iOS

💻 + + + Shoaib Merchant
Shoaib Merchant

💻 + Michael Gleason
Michael Gleason

💻 + Ana Gelez
Ana Gelez

💻 + Tom König
Tom König

💻 + Pierre Barre
Pierre Barre

💻 + Jean-Baptiste Skutnik
Jean-Baptiste Skutnik

💻 + Adam Chappell
Adam Chappell

💻 diff --git a/cryptovec/Cargo.toml b/cryptovec/Cargo.toml index 92450230..e04cd6f5 100644 --- a/cryptovec/Cargo.toml +++ b/cryptovec/Cargo.toml @@ -7,8 +7,11 @@ include = ["Cargo.toml", "src/lib.rs"] license = "Apache-2.0" name = "russh-cryptovec" repository = "https://github.com/warp-tech/russh" -version = "0.7.0" +version = "0.7.3" +rust-version = "1.60" [dependencies] libc = "0.2" + +[target.'cfg(target_os = "windows")'.dependencies] winapi = {version = "0.3", features = ["basetsd", "minwindef", "memoryapi"]} diff --git a/cryptovec/src/lib.rs b/cryptovec/src/lib.rs index 8ecd1f0d..256b7fc2 100644 --- a/cryptovec/src/lib.rs +++ b/cryptovec/src/lib.rs @@ -248,7 +248,15 @@ impl CryptoVec { let next_capacity = size.next_power_of_two(); let old_ptr = self.p; let next_layout = std::alloc::Layout::from_size_align_unchecked(next_capacity, 1); - self.p = std::alloc::alloc_zeroed(next_layout); + let new_ptr = std::alloc::alloc_zeroed(next_layout); + if new_ptr.is_null() { + #[allow(clippy::panic)] + { + panic!("Realloc failed, pointer = {:?} {:?}", self, size) + } + } + + self.p = new_ptr; mlock(self.p, next_capacity); if self.capacity > 0 { @@ -261,15 +269,8 @@ impl CryptoVec { std::alloc::dealloc(old_ptr, layout); } - if self.p.is_null() { - #[allow(clippy::panic)] - { - panic!("Realloc failed, pointer = {:?} {:?}", self, size) - } - } else { - self.capacity = next_capacity; - self.size = size; - } + self.capacity = next_capacity; + self.size = size; } } } @@ -429,3 +430,23 @@ impl Drop for CryptoVec { } } } + +#[cfg(test)] +mod tests { + use super::*; + + // If `resize` is called with a size that is too large to be allocated, it + // should panic, and not segfault or fail silently. + #[test] + fn large_resize_panics() { + let result = std::panic::catch_unwind(|| { + let mut vec = CryptoVec::new(); + // Write something into the vector, so that there is something to + // copy when reallocating, to test all code paths. + vec.push(42); + + vec.resize(1_000_000_000_000) + }); + assert!(result.is_err()); + } +} diff --git a/pageant/Cargo.toml b/pageant/Cargo.toml new file mode 100644 index 00000000..2622b465 --- /dev/null +++ b/pageant/Cargo.toml @@ -0,0 +1,27 @@ +[package] +authors = ["Eugene "] +description = "Pageant SSH agent transport client." +documentation = "https://docs.rs/pageant" +edition = "2018" +license = "Apache-2.0" +name = "pageant" +repository = "https://github.com/warp-tech/russh" +version = "0.0.1-beta.3" +rust-version = "1.65" + +[dependencies] +futures = { workspace = true } +thiserror = { workspace = true } +rand = { workspace = true } +tokio = { workspace = true, features = ["io-util", "rt"] } +bytes = "1.7" +delegate = "0.12" + +[target.'cfg(windows)'.dependencies] +windows = { version = "0.58", features = [ + "Win32_UI_WindowsAndMessaging", + "Win32_System_Memory", + "Win32_Security", + "Win32_System_Threading", + "Win32_System_DataExchange", +] } diff --git a/pageant/src/lib.rs b/pageant/src/lib.rs new file mode 100644 index 00000000..7af6c0c6 --- /dev/null +++ b/pageant/src/lib.rs @@ -0,0 +1,11 @@ +//! # Pageant SSH agent transport protocol implementation +//! +//! This crate provides a [PageantStream] type that implements [AsyncRead] and [AsyncWrite] traits and can be used to talk to a running Pageant instance. +//! +//! This crate only implements the transport, not the actual SSH agent protocol. + +#[cfg(windows)] +mod pageant_impl; + +#[cfg(windows)] +pub use pageant_impl::*; diff --git a/pageant/src/pageant_impl.rs b/pageant/src/pageant_impl.rs new file mode 100644 index 00000000..a02fe0a1 --- /dev/null +++ b/pageant/src/pageant_impl.rs @@ -0,0 +1,285 @@ +use std::io::IoSlice; +use std::mem::size_of; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::BytesMut; +use delegate::delegate; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf}; +use windows::core::HSTRING; +use windows::Win32::Foundation::{CloseHandle, HANDLE, HWND, INVALID_HANDLE_VALUE, LPARAM, WPARAM}; +use windows::Win32::Security::{ + GetTokenInformation, InitializeSecurityDescriptor, SetSecurityDescriptorOwner, TokenUser, + PSECURITY_DESCRIPTOR, SECURITY_ATTRIBUTES, SECURITY_DESCRIPTOR, TOKEN_QUERY, TOKEN_USER, +}; +use windows::Win32::System::DataExchange::COPYDATASTRUCT; +use windows::Win32::System::Memory::{ + CreateFileMappingW, MapViewOfFile, UnmapViewOfFile, FILE_MAP_WRITE, MEMORY_MAPPED_VIEW_ADDRESS, + PAGE_READWRITE, +}; +use windows::Win32::System::Threading::{GetCurrentProcess, OpenProcessToken}; +use windows::Win32::UI::WindowsAndMessaging::{FindWindowW, SendMessageA, WM_COPYDATA}; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Pageant not found")] + NotFound, + + #[error("Buffer overflow")] + Overflow, + + #[error("No response from Pageant")] + NoResponse, + + #[error(transparent)] + WindowsError(#[from] windows::core::Error), +} + +impl Error { + fn from_win32() -> Self { + Self::WindowsError(windows::core::Error::from_win32()) + } +} + +/// Pageant transport stream. Implements [AsyncRead] and [AsyncWrite]. +/// +/// The stream has a unique cookie and requests made in the same stream are considered the same "session". +pub struct PageantStream { + stream: DuplexStream, +} + +impl PageantStream { + pub fn new() -> Self { + let (one, mut two) = tokio::io::duplex(_AGENT_MAX_MSGLEN * 100); + + let cookie = rand::random::().to_string(); + tokio::spawn(async move { + let mut buf = BytesMut::new(); + while let Ok(n) = two.read_buf(&mut buf).await { + if n == 0 { + break; + } + let msg = buf.split().freeze(); + let response = query_pageant_direct(cookie.clone(), &msg).unwrap(); + two.write_all(&response).await? + } + std::io::Result::Ok(()) + }); + + Self { stream: one } + } +} + +impl Default for PageantStream { + fn default() -> Self { + Self::new() + } +} + +impl AsyncRead for PageantStream { + delegate! { + to Pin::new(&mut self.stream) { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll>; + + } + } +} + +impl AsyncWrite for PageantStream { + delegate! { + to Pin::new(&mut self.stream) { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll>; + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll>; + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + } + + to Pin::new(&self.stream) { + fn is_write_vectored(&self) -> bool; + } + } +} + +struct MemoryMap { + filemap: HANDLE, + view: MEMORY_MAPPED_VIEW_ADDRESS, + length: usize, + pos: usize, +} + +impl MemoryMap { + fn new( + name: String, + length: usize, + security_attributes: Option, + ) -> Result { + let filemap = unsafe { + CreateFileMappingW( + INVALID_HANDLE_VALUE, + security_attributes.map(|sa| &sa as *const _), + PAGE_READWRITE, + 0, + length as u32, + &HSTRING::from(name.clone()), + ) + }?; + if filemap.is_invalid() { + return Err(Error::from_win32()); + } + let view = unsafe { MapViewOfFile(filemap, FILE_MAP_WRITE, 0, 0, 0) }; + Ok(Self { + filemap, + view, + length, + pos: 0, + }) + } + + fn seek(&mut self, pos: usize) { + self.pos = pos; + } + + fn write(&mut self, data: &[u8]) -> Result<(), Error> { + if self.pos + data.len() > self.length { + return Err(Error::Overflow); + } + + unsafe { + std::ptr::copy_nonoverlapping( + &data[0] as *const u8, + self.view.Value.add(self.pos) as *mut u8, + data.len(), + ); + } + self.pos += data.len(); + Ok(()) + } + + fn read(&mut self, n: usize) -> Vec { + let out = vec![0; n]; + unsafe { + std::ptr::copy_nonoverlapping( + self.view.Value.add(self.pos) as *const u8, + out.as_ptr() as *mut u8, + n, + ); + } + self.pos += n; + out + } +} + +impl Drop for MemoryMap { + fn drop(&mut self) { + unsafe { + let _ = UnmapViewOfFile(self.view); + let _ = CloseHandle(self.filemap); + } + } +} + +fn find_pageant_window() -> Result { + let w = unsafe { FindWindowW(&HSTRING::from("Pageant"), &HSTRING::from("Pageant")) }?; + if w.is_invalid() { + return Err(Error::NotFound); + } + Ok(w) +} + +const _AGENT_COPYDATA_ID: u64 = 0x804E50BA; +const _AGENT_MAX_MSGLEN: usize = 8192; + +pub fn is_pageant_running() -> bool { + find_pageant_window().is_ok() +} + +/// Send a one-off query to Pageant and return a response. +pub fn query_pageant_direct(cookie: String, msg: &[u8]) -> Result, Error> { + let hwnd = find_pageant_window()?; + let map_name = format!("PageantRequest{cookie}"); + + let user = unsafe { + let mut process_token = HANDLE::default(); + OpenProcessToken( + GetCurrentProcess(), + TOKEN_QUERY, + &mut process_token as *mut _, + )?; + + let mut info_size = 0; + let _ = GetTokenInformation(process_token, TokenUser, None, 0, &mut info_size); + + let mut buffer = vec![0; info_size as usize]; + GetTokenInformation( + process_token, + TokenUser, + Some(buffer.as_mut_ptr() as *mut _), + buffer.len() as u32, + &mut info_size, + )?; + let user: TOKEN_USER = *(buffer.as_ptr() as *const _); + let _ = CloseHandle(process_token); + user + }; + + let mut sd = SECURITY_DESCRIPTOR::default(); + let sa = SECURITY_ATTRIBUTES { + lpSecurityDescriptor: &mut sd as *mut _ as *mut _, + bInheritHandle: true.into(), + ..Default::default() + }; + + let psd = PSECURITY_DESCRIPTOR(&mut sd as *mut _ as *mut _); + + unsafe { + InitializeSecurityDescriptor(psd, 1)?; + SetSecurityDescriptorOwner(psd, user.User.Sid, false)?; + } + + let mut map: MemoryMap = MemoryMap::new(map_name.clone(), _AGENT_MAX_MSGLEN, Some(sa))?; + map.write(msg)?; + + let mut char_buffer = map_name.as_bytes().to_vec(); + char_buffer.push(0); + let cds = COPYDATASTRUCT { + dwData: _AGENT_COPYDATA_ID as usize, + cbData: char_buffer.len() as u32, + lpData: char_buffer.as_ptr() as *mut _, + }; + + let response = unsafe { + SendMessageA( + hwnd, + WM_COPYDATA, + WPARAM(size_of::()), + LPARAM(&cds as *const _ as isize), + ) + }; + + if response.0 == 0 { + return Err(Error::NoResponse); + } + + map.seek(0); + let mut buf = map.read(4); + let size = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + buf.extend(map.read(size)); + + Ok(buf) +} diff --git a/russh-config/Cargo.toml b/russh-config/Cargo.toml index 1ad2846b..0f776b96 100644 --- a/russh-config/Cargo.toml +++ b/russh-config/Cargo.toml @@ -7,12 +7,14 @@ include = ["Cargo.toml", "src/lib.rs", "src/proxy.rs"] license = "Apache-2.0" name = "russh-config" repository = "https://github.com/warp-tech/russh" -version = "0.7.0" +version = "0.7.1" +rust-version = "1.65" [dependencies] -dirs-next = "2.0" -futures = "0.3" -log = "0.4" -thiserror = "1.0" -tokio = {version = "1.0", features = ["io-util", "net", "macros", "process"]} +home = "0.5" +futures = { workspace = true } +globset = "0.4.14" +log = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["io-util", "net", "macros", "process"] } whoami = "1.2" diff --git a/russh-config/src/lib.rs b/russh-config/src/lib.rs index cdf4d95a..2dec9406 100644 --- a/russh-config/src/lib.rs +++ b/russh-config/src/lib.rs @@ -8,6 +8,7 @@ use std::io::Read; use std::net::ToSocketAddrs; use std::path::Path; +use globset::Glob; use log::debug; use thiserror::*; @@ -34,6 +35,7 @@ pub struct Config { pub port: u16, pub identity_file: Option, pub proxy_command: Option, + pub proxy_jump: Option, pub add_keys_to_agent: AddKeysToAgent, } @@ -45,22 +47,30 @@ impl Config { port: 22, identity_file: None, proxy_command: None, + proxy_jump: None, add_keys_to_agent: AddKeysToAgent::default(), } } } impl Config { - fn update_proxy_command(&mut self) { - if let Some(ref mut prox) = self.proxy_command { - *prox = prox.replace("%h", &self.host_name); - *prox = prox.replace("%p", &format!("{}", self.port)); - } + // Look for any of the ssh_config(5) percent-style tokens and expand them + // based on current data in the struct, returning a new String. This function + // can be employed late/lazy eg just before establishing a stream using ProxyCommand + // but also can be used to modify Hostname as config parse time + fn expand_tokens(&self, original: &str) -> String { + let mut string = original.to_string(); + string = string.replace("%u", &self.user); + string = string.replace("%h", &self.host_name); // remote hostname (from context "host") + string = string.replace("%H", &self.host_name); // remote hostname (from context "host") + string = string.replace("%p", &format!("{}", self.port)); // original typed hostname (from context "host") + string = string.replace("%%", "%"); + string } - pub async fn stream(&mut self) -> Result { - self.update_proxy_command(); + pub async fn stream(&self) -> Result { if let Some(ref proxy_command) = self.proxy_command { + let proxy_command = self.expand_tokens(proxy_command); let cmd: Vec<&str> = proxy_command.split(' ').collect(); Stream::proxy_command(cmd.first().unwrap_or(&""), cmd.get(1..).unwrap_or(&[])) .await @@ -76,7 +86,7 @@ impl Config { } pub fn parse_home(host: &str) -> Result { - let mut home = if let Some(home) = dirs_next::home_dir() { + let mut home = if let Some(home) = home::home_dir() { home } else { return Err(Error::NoHome); @@ -93,26 +103,21 @@ pub fn parse_path>(path: P, host: &str) -> Result parse(&s, host) } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] pub enum AddKeysToAgent { Yes, Confirm, Ask, + #[default] No, } -impl Default for AddKeysToAgent { - fn default() -> Self { - AddKeysToAgent::No - } -} - pub fn parse(file: &str, host: &str) -> Result { let mut config: Option = None; for line in file.lines() { - let line = line.trim(); - if let Some(n) = line.find(' ') { - let (key, value) = line.split_at(n); + let tokens = line.trim().splitn(2, ' ').collect::>(); + if tokens.len() == 2 { + let (key, value) = (tokens.first().unwrap_or(&""), tokens.get(1).unwrap_or(&"")); let lower = key.to_lowercase(); if let Some(ref mut config) = config { match lower.as_str() { @@ -121,10 +126,7 @@ pub fn parse(file: &str, host: &str) -> Result { config.user.clear(); config.user.push_str(value.trim_start()); } - "hostname" => { - config.host_name.clear(); - config.host_name.push_str(value.trim_start()) - } + "hostname" => config.host_name = config.expand_tokens(value.trim_start()), "port" => { if let Ok(port) = value.trim_start().parse() { config.port = port @@ -133,7 +135,7 @@ pub fn parse(file: &str, host: &str) -> Result { "identityfile" => { let id = value.trim_start(); if id.starts_with("~/") { - if let Some(mut home) = dirs_next::home_dir() { + if let Some(mut home) = home::home_dir() { home.push(id.split_at(2).1); config.identity_file = Some( home.to_str() @@ -153,6 +155,7 @@ pub fn parse(file: &str, host: &str) -> Result { } } "proxycommand" => config.proxy_command = Some(value.trim_start().to_string()), + "proxyjump" => config.proxy_jump = Some(value.trim_start().to_string()), "addkeystoagent" => match value.to_lowercase().as_str() { "yes" => config.add_keys_to_agent = AddKeysToAgent::Yes, "confirm" => config.add_keys_to_agent = AddKeysToAgent::Confirm, @@ -163,7 +166,11 @@ pub fn parse(file: &str, host: &str) -> Result { debug!("{:?}", key); } } - } else if lower.as_str() == "host" && value.trim_start() == host { + } else if lower.as_str() == "host" + && value + .split_whitespace() + .any(|x| check_host_against_glob_pattern(host, x)) + { let mut c = Config::default(host); c.port = 22; config = Some(c) @@ -176,3 +183,10 @@ pub fn parse(file: &str, host: &str) -> Result { Err(Error::HostNotFound) } } + +fn check_host_against_glob_pattern(candidate: &str, glob_pattern: &str) -> bool { + match Glob::new(glob_pattern) { + Ok(glob) => glob.compile_matcher().is_match(candidate), + _ => false, + } +} diff --git a/russh-keys/Cargo.toml b/russh-keys/Cargo.toml index fd5bcc80..80fbb610 100644 --- a/russh-keys/Cargo.toml +++ b/russh-keys/Cargo.toml @@ -4,71 +4,76 @@ description = "Deal with SSH keys: load them, decrypt them, call an SSH agent." documentation = "https://docs.rs/russh-keys" edition = "2018" homepage = "https://github.com/warp-tech/russh" -include = [ - "Cargo.toml", - "src/lib.rs", - "src/agent/mod.rs", - "src/agent/msg.rs", - "src/agent/server.rs", - "src/agent/client.rs", - "src/bcrypt_pbkdf.rs", - "src/blowfish.rs", - "src/encoding.rs", - "src/format/mod.rs", - "src/format/openssh.rs", - "src/format/pkcs5.rs", - "src/format/pkcs8.rs", - "src/key.rs", - "src/signature.rs", -] keywords = ["ssh"] license = "Apache-2.0" name = "russh-keys" repository = "https://github.com/warp-tech/russh" -version = "0.37.1" +version = "0.46.0-beta.3" +rust-version = "1.65" [dependencies] -aes = "0.8" -async-trait = "0.1.72" +aes = { workspace = true } +async-trait = { workspace = true } bcrypt-pbkdf = "0.10" -bit-vec = "0.6" cbc = "0.1" ctr = "0.9" block-padding = { version = "0.3", features = ["std"] } -byteorder = "1.4" +byteorder = { workspace = true } data-encoding = "2.3" -dirs = "5.0" -ed25519-dalek = { version= "2.0", features = ["rand_core"] } -futures = "0.3" -hmac = "0.12" +digest = { workspace = true } +der = "0.7" +home = "0.5" +ecdsa = "0.16" +ed25519-dalek = { version = "2.0", features = ["rand_core", "pkcs8"] } +elliptic-curve = "0.13" +futures = { workspace = true } +hmac = { workspace = true } inout = { version = "0.1", features = ["std"] } -log = "0.4" +log = { workspace = true } md5 = "0.7" -num-bigint = "0.4" num-integer = "0.1" -openssl = { version = "0.10", optional = true } +openssl = { workspace = true, optional = true } +p256 = "0.13" +p384 = "0.13" +p521 = "0.13" pbkdf2 = "0.11" -rand = "0.7" +pkcs1 = "0.7" +pkcs5 = "0.7" +pkcs8 = { version = "0.10", features = ["pkcs5", "encryption"] } +rand = { workspace = true } rand_core = { version = "0.6.4", features = ["std"] } +rsa = "0.9" russh-cryptovec = { version = "0.7.0", path = "../cryptovec" } +sec1 = { version = "0.7", features = ["pkcs8"] } serde = { version = "1.0", features = ["derive"] } -sha2 = "0.10" -thiserror = "1.0" -tokio = { version = "1.17.0", features = [ +sha1 = { workspace = true } +sha2 = { workspace = true } +spki = "0.7" +ssh-encoding = { workspace = true } +ssh-key = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = [ "io-util", "rt-multi-thread", "time", "net", ] } -tokio-stream = { version = "0.1", features = ["net"] } -yasna = { version = "0.5.0", features = ["bit-vec", "num-bigint"] } +tokio-stream = { workspace = true } +typenum = "1.17" +yasna = { version = "0.5.0", features = ["bit-vec", "num-bigint"], optional = true } +zeroize = "1.7" [features] vendored-openssl = ["openssl", "openssl/vendored"] +legacy-ed25519-pkcs8-parser = ["yasna"] + +[target.'cfg(windows)'.dependencies] +pageant = { version = "0.0.1-beta.3", path = "../pageant" } [dev-dependencies] env_logger = "0.10" tempdir = "0.3" +tokio = { workspace = true, features = ["test-util", "macros", "process"] } [package.metadata.docs.rs] features = ["openssl"] diff --git a/russh-keys/src/agent/client.rs b/russh-keys/src/agent/client.rs index 1e730030..82d78980 100644 --- a/russh-keys/src/agent/client.rs +++ b/russh-keys/src/agent/client.rs @@ -1,7 +1,7 @@ use std::convert::TryFrom; use byteorder::{BigEndian, ByteOrder}; -use log::{debug, info}; +use log::debug; use russh_cryptovec::CryptoVec; use tokio; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -9,16 +9,35 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use super::{msg, Constraint}; use crate::encoding::{Encoding, Reader}; use crate::key::{PublicKey, SignatureHash}; -use crate::{key, Error}; +use crate::{key, protocol, Error, PublicKeyBase64}; + +pub trait AgentStream: AsyncRead + AsyncWrite {} + +impl AgentStream for S {} /// SSH agent client. -pub struct AgentClient { +pub struct AgentClient { stream: S, buf: CryptoVec, } +impl AgentClient { + /// Wraps the internal stream in a Box, allowing different client + /// implementations to have the same type + pub fn dynamic(self) -> AgentClient> { + AgentClient { + stream: Box::new(self.stream), + buf: self.buf, + } + } + + pub fn into_inner(self) -> Box { + Box::new(self.stream) + } +} + // https://tools.ietf.org/html/draft-miller-ssh-agent-00#section-4.1 -impl AgentClient { +impl AgentClient { /// Build a future that connects to an SSH agent via the provided /// stream (on Unix, usually a Unix-domain socket). pub fn connect(stream: S) -> Self { @@ -31,7 +50,7 @@ impl AgentClient { #[cfg(unix)] impl AgentClient { - /// Build a future that connects to an SSH agent via the provided + /// Connect to an SSH agent via the provided /// stream (on Unix, usually a Unix-domain socket). pub async fn connect_uds>(path: P) -> Result { let stream = tokio::net::UnixStream::connect(path).await?; @@ -41,8 +60,8 @@ impl AgentClient { }) } - /// Build a future that connects to an SSH agent via the provided - /// stream (on Unix, usually a Unix-domain socket). + /// Connect to an SSH agent specified by the SSH_AUTH_SOCK + /// environment variable. pub async fn connect_env() -> Result { let var = if let Ok(var) = std::env::var("SSH_AUTH_SOCK") { var @@ -58,16 +77,39 @@ impl AgentClient { } } -#[cfg(not(unix))] -impl AgentClient { - /// Build a future that connects to an SSH agent via the provided - /// stream (on Unix, usually a Unix-domain socket). - pub async fn connect_env() -> Result { - Err(Error::AgentFailure) +#[cfg(windows)] +const ERROR_PIPE_BUSY: u32 = 231u32; + +#[cfg(windows)] +impl AgentClient { + /// Connect to a running Pageant instance + pub async fn connect_pageant() -> Self { + Self::connect(pageant::PageantStream::new()) + } +} + +#[cfg(windows)] +impl AgentClient { + /// Connect to an SSH agent via a Windows named pipe + pub async fn connect_named_pipe>(path: P) -> Result { + let stream = loop { + match tokio::net::windows::named_pipe::ClientOptions::new().open(path.as_ref()) { + Ok(client) => break client, + Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (), + Err(e) => return Err(e.into()), + } + + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + }; + + Ok(AgentClient { + stream, + buf: CryptoVec::new(), + }) } } -impl AgentClient { +impl AgentClient { async fn read_response(&mut self) -> Result<(), Error> { // Writing the message self.stream.write_all(&self.buf).await?; @@ -87,6 +129,15 @@ impl AgentClient { Ok(()) } + async fn read_success(&mut self) -> Result<(), Error> { + self.read_response().await?; + if self.buf.first() == Some(&msg::SUCCESS) { + Ok(()) + } else { + Err(Error::AgentFailure) + } + } + /// Send a key to the agent, with a (possibly empty) slice of /// constraints to apply when using the key to sign. pub async fn add_identity( @@ -94,6 +145,8 @@ impl AgentClient { key: &key::KeyPair, constraints: &[Constraint], ) -> Result<(), Error> { + // See IETF draft-miller-ssh-agent-13, section 3.2 for format. + // https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent self.buf.clear(); self.buf.resize(4); if constraints.is_empty() { @@ -110,33 +163,27 @@ impl AgentClient { self.buf.extend(pair.verifying_key().as_bytes()); self.buf.extend_ssh_string(b""); } - #[cfg(feature = "openssl")] #[allow(clippy::unwrap_used)] // key is known to be private key::KeyPair::RSA { ref key, .. } => { self.buf.extend_ssh_string(b"ssh-rsa"); - self.buf.extend_ssh_mpint(&key.n().to_vec()); - self.buf.extend_ssh_mpint(&key.e().to_vec()); - self.buf.extend_ssh_mpint(&key.d().to_vec()); - if let Some(iqmp) = key.iqmp() { - self.buf.extend_ssh_mpint(&iqmp.to_vec()); - } else { - let mut ctx = openssl::bn::BigNumContext::new()?; - let mut iqmp = openssl::bn::BigNum::new()?; - iqmp.mod_inverse(key.p().unwrap(), key.q().unwrap(), &mut ctx)?; - self.buf.extend_ssh_mpint(&iqmp.to_vec()); - } - self.buf.extend_ssh_mpint(&key.p().unwrap().to_vec()); - self.buf.extend_ssh_mpint(&key.q().unwrap().to_vec()); - self.buf.extend_ssh_string(b""); + self.buf + .extend_ssh(&protocol::RsaPrivateKey::try_from(key)?); + } + key::KeyPair::EC { ref key } => { + self.buf.extend_ssh_string(key.algorithm().as_bytes()); + self.buf.extend_ssh_string(key.ident().as_bytes()); + self.buf + .extend_ssh_string(&key.to_public_key().to_sec1_bytes()); + self.buf.extend_ssh_mpint(&key.to_secret_bytes()); + self.buf.extend_ssh_string(b""); // comment } } if !constraints.is_empty() { - self.buf.push_u32_be(constraints.len() as u32); for cons in constraints { match *cons { Constraint::KeyLifetime { seconds } => { self.buf.push(msg::CONSTRAIN_LIFETIME); - self.buf.push_u32_be(seconds) + self.buf.push_u32_be(seconds); } Constraint::Confirm => self.buf.push(msg::CONSTRAIN_CONFIRM), Constraint::Extensions { @@ -153,7 +200,7 @@ impl AgentClient { let len = self.buf.len() - 4; BigEndian::write_u32(&mut self.buf[..], len as u32); - self.read_response().await?; + self.read_success().await?; Ok(()) } @@ -243,34 +290,12 @@ impl AgentClient { let mut r = self.buf.reader(1); let n = r.read_u32()?; for _ in 0..n { - let key = r.read_string()?; - let _ = r.read_string()?; - let mut r = key.reader(0); - let t = r.read_string()?; - debug!("t = {:?}", std::str::from_utf8(t)); - match t { - #[cfg(feature = "openssl")] - b"ssh-rsa" => { - let e = r.read_mpint()?; - let n = r.read_mpint()?; - use openssl::bn::BigNum; - use openssl::pkey::PKey; - use openssl::rsa::Rsa; - keys.push(PublicKey::RSA { - key: key::OpenSSLPKey(PKey::from_rsa(Rsa::from_public_components( - BigNum::from_slice(n)?, - BigNum::from_slice(e)?, - )?)?), - hash: SignatureHash::SHA2_512, - }) - } - b"ssh-ed25519" => keys.push(PublicKey::Ed25519( - ed25519_dalek::VerifyingKey::try_from(r.read_string()?)?, - )), - t => { - info!("Unsupported key type: {:?}", std::str::from_utf8(t)) - } - } + let key_blob = r.read_string()?; + let _comment = r.read_string()?; + keys.push(key::parse_public_key( + key_blob, + Some(SignatureHash::SHA2_512), + )?); } } @@ -322,7 +347,6 @@ impl AgentClient { self.buf.extend_ssh_string(data); debug!("public = {:?}", public); let hash = match public { - #[cfg(feature = "openssl")] PublicKey::RSA { hash, .. } => match hash { SignatureHash::SHA2_256 => 2, SignatureHash::SHA2_512 => 4, @@ -467,8 +491,8 @@ impl AgentClient { self.buf.clear(); self.buf.resize(4); self.buf.push(msg::REMOVE_ALL_IDENTITIES); - BigEndian::write_u32(&mut self.buf[..], 5); - self.read_response().await?; + BigEndian::write_u32(&mut self.buf[..], 1); + self.read_success().await?; Ok(()) } @@ -505,14 +529,11 @@ impl AgentClient { fn key_blob(public: &key::PublicKey, buf: &mut CryptoVec) -> Result<(), Error> { match *public { - #[cfg(feature = "openssl")] PublicKey::RSA { ref key, .. } => { buf.extend(&[0, 0, 0, 0]); let len0 = buf.len(); buf.extend_ssh_string(b"ssh-rsa"); - let rsa = key.0.rsa()?; - buf.extend_ssh_mpint(&rsa.e().to_vec()); - buf.extend_ssh_mpint(&rsa.n().to_vec()); + buf.extend_ssh(&protocol::RsaPublicKey::from(key)); let len1 = buf.len(); #[allow(clippy::indexing_slicing)] // length is known BigEndian::write_u32(&mut buf[5..], (len1 - len0) as u32); @@ -526,6 +547,9 @@ fn key_blob(public: &key::PublicKey, buf: &mut CryptoVec) -> Result<(), Error> { #[allow(clippy::indexing_slicing)] // length is known BigEndian::write_u32(&mut buf[5..], (len1 - len0) as u32); } + PublicKey::EC { .. } => { + buf.extend_ssh_string(&public.public_key_bytes()); + } } Ok(()) } diff --git a/russh-keys/src/agent/msg.rs b/russh-keys/src/agent/msg.rs index a77c5091..d732e674 100644 --- a/russh-keys/src/agent/msg.rs +++ b/russh-keys/src/agent/msg.rs @@ -19,4 +19,5 @@ pub const EXTENSION: u8 = 27; pub const CONSTRAIN_LIFETIME: u8 = 1; pub const CONSTRAIN_CONFIRM: u8 = 2; -pub const CONSTRAIN_EXTENSION: u8 = 3; +// pub const CONSTRAIN_MAXSIGN: u8 = 3; +pub const CONSTRAIN_EXTENSION: u8 = 255; diff --git a/russh-keys/src/agent/server.rs b/russh-keys/src/agent/server.rs index c61a8e0d..be89509a 100644 --- a/russh-keys/src/agent/server.rs +++ b/russh-keys/src/agent/server.rs @@ -15,8 +15,6 @@ use {std, tokio}; use super::{msg, Constraint}; use crate::encoding::{Encoding, Position, Reader}; -#[cfg(feature = "openssl")] -use crate::key::SignatureHash; use crate::{key, Error}; #[derive(Clone)] @@ -252,82 +250,27 @@ impl Result { - let pos0 = r.position; - let t = r.read_string()?; - let (blob, key) = match t { - b"ssh-ed25519" => { - let pos1 = r.position; - let concat = r.read_string()?; - let _comment = r.read_string()?; - #[allow(clippy::indexing_slicing)] // length checked before - let secret = ed25519_dalek::SigningKey::try_from( - concat.get(..32).ok_or(Error::KeyIsCorrupt)?, - ).map_err(|_| Error::KeyIsCorrupt)?; + let (blob, key_pair) = { + use ssh_encoding::{Decode, Encode}; - writebuf.push(msg::SUCCESS); + let private_key = ssh_key::private::PrivateKey::new( + ssh_key::private::KeypairData::decode(&mut r)?, + "", + )?; + let _comment = r.read_string()?; + let key_pair = key::KeyPair::try_from(&private_key)?; - #[allow(clippy::indexing_slicing)] // positions checked before - (self.buf[pos0..pos1].to_vec(), key::KeyPair::Ed25519(secret)) - } - #[cfg(feature = "openssl")] - b"ssh-rsa" => { - use openssl::bn::{BigNum, BigNumContext}; - use openssl::rsa::Rsa; - let n = r.read_mpint()?; - let e = r.read_mpint()?; - let d = BigNum::from_slice(r.read_mpint()?)?; - let q_inv = r.read_mpint()?; - let p = BigNum::from_slice(r.read_mpint()?)?; - let q = BigNum::from_slice(r.read_mpint()?)?; - let (dp, dq) = { - let one = BigNum::from_u32(1)?; - let p1 = p.as_ref() - one.as_ref(); - let q1 = q.as_ref() - one.as_ref(); - let mut context = BigNumContext::new()?; - let mut dp = BigNum::new()?; - let mut dq = BigNum::new()?; - dp.checked_rem(&d, &p1, &mut context)?; - dq.checked_rem(&d, &q1, &mut context)?; - (dp, dq) - }; - let _comment = r.read_string()?; - let key = Rsa::from_private_components( - BigNum::from_slice(n)?, - BigNum::from_slice(e)?, - d, - p, - q, - dp, - dq, - BigNum::from_slice(q_inv)?, - )?; - - let len0 = writebuf.len(); - writebuf.extend_ssh_string(b"ssh-rsa"); - writebuf.extend_ssh_mpint(e); - writebuf.extend_ssh_mpint(n); + let mut blob = Vec::new(); + private_key.public_key().key_data().encode(&mut blob)?; - #[allow(clippy::indexing_slicing)] // length is known - let blob = writebuf[len0..].to_vec(); - writebuf.resize(len0); - writebuf.push(msg::SUCCESS); - ( - blob, - key::KeyPair::RSA { - key, - hash: SignatureHash::SHA2_256, - }, - ) - } - _ => return Ok(false), + (blob, key_pair) }; + writebuf.push(msg::SUCCESS); let mut w = self.keys.0.write().or(Err(Error::AgentFailure))?; let now = SystemTime::now(); if constrained { - let n = r.read_u32()?; let mut c = Vec::new(); - for _ in 0..n { - let t = r.read_byte()?; + while let Ok(t) = r.read_byte() { if t == msg::CONSTRAIN_LIFETIME { let seconds = r.read_u32()?; c.push(Constraint::KeyLifetime { seconds }); @@ -352,9 +295,9 @@ impl, + pkey: PKey, +} + +impl RsaPublic { + pub fn verify_detached(&self, hash: &SignatureHash, msg: &[u8], sig: &[u8]) -> bool { + openssl::sign::Verifier::new(message_digest_for(hash), &self.pkey) + .and_then(|mut v| v.verify_oneshot(sig, msg)) + .unwrap_or(false) + } +} + +impl TryFrom<&protocol::RsaPublicKey<'_>> for RsaPublic { + type Error = Error; + + fn try_from(pk: &protocol::RsaPublicKey<'_>) -> Result { + let key = Rsa::from_public_components( + BigNum::from_slice(&pk.modulus)?, + BigNum::from_slice(&pk.public_exponent)?, + )?; + Ok(Self { + pkey: PKey::from_rsa(key.clone())?, + key, + }) + } +} + +impl<'a> From<&RsaPublic> for protocol::RsaPublicKey<'a> { + fn from(key: &RsaPublic) -> Self { + Self { + modulus: key.key.n().to_vec().into(), + public_exponent: key.key.e().to_vec().into(), + } + } +} + +impl PartialEq for RsaPublic { + fn eq(&self, b: &RsaPublic) -> bool { + self.pkey.public_eq(&b.pkey) + } +} + +impl Eq for RsaPublic {} + +impl std::fmt::Debug for RsaPublic { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "RsaPublic {{ (hidden) }}") + } +} + +#[derive(Clone)] +pub struct RsaPrivate { + key: Rsa, + pkey: PKey, +} + +impl RsaPrivate { + pub fn new( + sk: &protocol::RsaPrivateKey<'_>, + extra: Option<&RsaCrtExtra<'_>>, + ) -> Result { + let (d, p, q) = ( + BigNum::from_slice(&sk.private_exponent)?, + BigNum::from_slice(&sk.prime1)?, + BigNum::from_slice(&sk.prime2)?, + ); + let (dp, dq) = if let Some(extra) = extra { + ( + BigNum::from_slice(&extra.dp)?, + BigNum::from_slice(&extra.dq)?, + ) + } else { + calc_dp_dq(d.as_ref(), p.as_ref(), q.as_ref())? + }; + let key = Rsa::from_private_components( + BigNum::from_slice(&sk.public_key.modulus)?, + BigNum::from_slice(&sk.public_key.public_exponent)?, + d, + p, + q, + dp, + dq, + BigNum::from_slice(&sk.coefficient)?, + )?; + key.check_key()?; + Ok(Self { + pkey: PKey::from_rsa(key.clone())?, + key, + }) + } + + pub fn new_from_der(der: &[u8]) -> Result { + let key = Rsa::private_key_from_der(der)?; + key.check_key()?; + Ok(Self { + pkey: PKey::from_rsa(key.clone())?, + key, + }) + } + + pub fn generate(bits: usize) -> Result { + let key = Rsa::generate(bits as u32)?; + Ok(Self { + pkey: PKey::from_rsa(key.clone())?, + key, + }) + } + + pub fn sign(&self, hash: &SignatureHash, msg: &[u8]) -> Result, Error> { + Ok( + openssl::sign::Signer::new(message_digest_for(hash), &self.pkey)? + .sign_oneshot_to_vec(msg)?, + ) + } +} + +impl<'a> TryFrom<&RsaPrivate> for protocol::RsaPrivateKey<'a> { + type Error = Error; + + fn try_from(key: &RsaPrivate) -> Result, Self::Error> { + let key = &key.key; + // We always set these. + if let (Some(p), Some(q), Some(iqmp)) = (key.p(), key.q(), key.iqmp()) { + Ok(protocol::RsaPrivateKey { + public_key: protocol::RsaPublicKey { + modulus: key.n().to_vec().into(), + public_exponent: key.e().to_vec().into(), + }, + private_exponent: key.d().to_vec().into(), + prime1: p.to_vec().into(), + prime2: q.to_vec().into(), + coefficient: iqmp.to_vec().into(), + comment: b"".as_slice().into(), + }) + } else { + Err(Error::KeyIsCorrupt) + } + } +} + +impl<'a> TryFrom<&RsaPrivate> for RsaCrtExtra<'a> { + type Error = Error; + + fn try_from(key: &RsaPrivate) -> Result, Self::Error> { + let key = &key.key; + // We always set these. + if let (Some(dp), Some(dq)) = (key.dmp1(), key.dmq1()) { + Ok(RsaCrtExtra { + dp: dp.to_vec().into(), + dq: dq.to_vec().into(), + }) + } else { + Err(Error::KeyIsCorrupt) + } + } +} + +impl<'a> From<&RsaPrivate> for protocol::RsaPublicKey<'a> { + fn from(key: &RsaPrivate) -> Self { + Self { + modulus: key.key.n().to_vec().into(), + public_exponent: key.key.e().to_vec().into(), + } + } +} + +impl TryFrom<&RsaPrivate> for RsaPublic { + type Error = Error; + + fn try_from(key: &RsaPrivate) -> Result { + let key = Rsa::from_public_components(key.key.n().to_owned()?, key.key.e().to_owned()?)?; + Ok(Self { + pkey: PKey::from_rsa(key.clone())?, + key, + }) + } +} + +fn message_digest_for(hash: &SignatureHash) -> MessageDigest { + match hash { + SignatureHash::SHA2_256 => MessageDigest::sha256(), + SignatureHash::SHA2_512 => MessageDigest::sha512(), + SignatureHash::SHA1 => MessageDigest::sha1(), + } +} + +fn calc_dp_dq(d: &BigNumRef, p: &BigNumRef, q: &BigNumRef) -> Result<(BigNum, BigNum), Error> { + let one = BigNum::from_u32(1)?; + let p1 = p - one.as_ref(); + let q1 = q - one.as_ref(); + let mut context = BigNumContext::new()?; + let mut dp = BigNum::new()?; + let mut dq = BigNum::new()?; + dp.checked_rem(d, &p1, &mut context)?; + dq.checked_rem(d, &q1, &mut context)?; + Ok((dp, dq)) +} diff --git a/russh-keys/src/backend_rust.rs b/russh-keys/src/backend_rust.rs new file mode 100644 index 00000000..9b568887 --- /dev/null +++ b/russh-keys/src/backend_rust.rs @@ -0,0 +1,184 @@ +use std::convert::TryFrom; + +use rsa::traits::{PrivateKeyParts, PublicKeyParts}; +use rsa::BigUint; + +use crate::key::{RsaCrtExtra, SignatureHash}; +use crate::{protocol, Error}; + +#[derive(Clone, PartialEq, Eq)] +pub struct RsaPublic { + key: rsa::RsaPublicKey, +} + +impl RsaPublic { + pub fn verify_detached(&self, hash: &SignatureHash, msg: &[u8], sig: &[u8]) -> bool { + self.key + .verify(signature_scheme_for_hash(hash), &hash_msg(hash, msg), sig) + .is_ok() + } +} + +impl TryFrom<&protocol::RsaPublicKey<'_>> for RsaPublic { + type Error = Error; + + fn try_from(pk: &protocol::RsaPublicKey<'_>) -> Result { + Ok(Self { + key: rsa::RsaPublicKey::new( + BigUint::from_bytes_be(&pk.modulus), + BigUint::from_bytes_be(&pk.public_exponent), + )?, + }) + } +} + +impl<'a> From<&RsaPublic> for protocol::RsaPublicKey<'a> { + fn from(key: &RsaPublic) -> Self { + Self { + modulus: key.key.n().to_bytes_be().into(), + public_exponent: key.key.e().to_bytes_be().into(), + } + } +} + +impl std::fmt::Debug for RsaPublic { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "RsaPublic {{ (hidden) }}") + } +} + +#[derive(Clone)] +pub struct RsaPrivate { + key: rsa::RsaPrivateKey, +} + +impl RsaPrivate { + pub fn new( + sk: &protocol::RsaPrivateKey<'_>, + extra: Option<&RsaCrtExtra<'_>>, + ) -> Result { + let mut key = rsa::RsaPrivateKey::from_components( + BigUint::from_bytes_be(&sk.public_key.modulus), + BigUint::from_bytes_be(&sk.public_key.public_exponent), + BigUint::from_bytes_be(&sk.private_exponent), + vec![ + BigUint::from_bytes_be(&sk.prime1), + BigUint::from_bytes_be(&sk.prime2), + ], + )?; + key.validate()?; + key.precompute()?; + + if Some(BigUint::from_bytes_be(&sk.coefficient)) != key.crt_coefficient() { + return Err(Error::KeyIsCorrupt); + } + if let Some(extra) = extra { + if ( + Some(&BigUint::from_bytes_be(&extra.dp)), + Some(&BigUint::from_bytes_be(&extra.dq)), + ) != (key.dp(), key.dq()) + { + return Err(Error::KeyIsCorrupt); + } + } + + Ok(Self { key }) + } + + pub fn new_from_der(der: &[u8]) -> Result { + use pkcs1::DecodeRsaPrivateKey; + Ok(Self { + key: rsa::RsaPrivateKey::from_pkcs1_der(der)?, + }) + } + + pub fn generate(bits: usize) -> Result { + Ok(Self { + key: rsa::RsaPrivateKey::new(&mut crate::key::safe_rng(), bits)?, + }) + } + + pub fn sign(&self, hash: &SignatureHash, msg: &[u8]) -> Result, Error> { + Ok(self + .key + .sign(signature_scheme_for_hash(hash), &hash_msg(hash, msg))?) + } +} + +impl<'a> TryFrom<&RsaPrivate> for protocol::RsaPrivateKey<'a> { + type Error = Error; + + fn try_from(key: &RsaPrivate) -> Result, Self::Error> { + let key = &key.key; + // We always precompute these. + if let ([p, q], Some(iqmp)) = (key.primes(), key.crt_coefficient()) { + Ok(protocol::RsaPrivateKey { + public_key: protocol::RsaPublicKey { + modulus: key.n().to_bytes_be().into(), + public_exponent: key.e().to_bytes_be().into(), + }, + private_exponent: key.d().to_bytes_be().into(), + prime1: p.to_bytes_be().into(), + prime2: q.to_bytes_be().into(), + coefficient: iqmp.to_bytes_be().into(), + comment: b"".as_slice().into(), + }) + } else { + Err(Error::KeyIsCorrupt) + } + } +} + +impl<'a> TryFrom<&RsaPrivate> for RsaCrtExtra<'a> { + type Error = Error; + + fn try_from(key: &RsaPrivate) -> Result, Self::Error> { + let key = &key.key; + // We always precompute these. + if let (Some(dp), Some(dq)) = (key.dp(), key.dq()) { + Ok(RsaCrtExtra { + dp: dp.to_bytes_be().into(), + dq: dq.to_bytes_be().into(), + }) + } else { + Err(Error::KeyIsCorrupt) + } + } +} + +impl<'a> From<&RsaPrivate> for protocol::RsaPublicKey<'a> { + fn from(key: &RsaPrivate) -> Self { + Self { + modulus: key.key.n().to_bytes_be().into(), + public_exponent: key.key.e().to_bytes_be().into(), + } + } +} + +impl TryFrom<&RsaPrivate> for RsaPublic { + type Error = Error; + + fn try_from(key: &RsaPrivate) -> Result { + Ok(Self { + key: key.key.to_public_key(), + }) + } +} + +fn signature_scheme_for_hash(hash: &SignatureHash) -> rsa::pkcs1v15::Pkcs1v15Sign { + use rsa::pkcs1v15::Pkcs1v15Sign; + match *hash { + SignatureHash::SHA2_256 => Pkcs1v15Sign::new::(), + SignatureHash::SHA2_512 => Pkcs1v15Sign::new::(), + SignatureHash::SHA1 => Pkcs1v15Sign::new::(), + } +} + +fn hash_msg(hash: &SignatureHash, msg: &[u8]) -> Vec { + use digest::Digest; + match *hash { + SignatureHash::SHA2_256 => sha2::Sha256::digest(msg).to_vec(), + SignatureHash::SHA2_512 => sha2::Sha512::digest(msg).to_vec(), + SignatureHash::SHA1 => sha1::Sha1::digest(msg).to_vec(), + } +} diff --git a/russh-keys/src/ec.rs b/russh-keys/src/ec.rs new file mode 100644 index 00000000..689ad15a --- /dev/null +++ b/russh-keys/src/ec.rs @@ -0,0 +1,263 @@ +use elliptic_curve::{Curve, CurveArithmetic, FieldBytes, FieldBytesSize}; + +use crate::key::safe_rng; +use crate::Error; + +// p521::{SigningKey, VerifyingKey} are wrapped versions and do not provide PartialEq and Eq, hence +// we make our own type alias here. +mod local_p521 { + use rand_core::CryptoRngCore; + use sha2::{Digest, Sha512}; + + pub type NistP521 = p521::NistP521; + pub type VerifyingKey = ecdsa::VerifyingKey; + pub type SigningKey = ecdsa::SigningKey; + pub type Signature = ecdsa::Signature; + pub type Result = ecdsa::Result; + + // Implement signing because p521::NistP521 does not implement DigestPrimitive trait. + pub fn try_sign_with_rng( + key: &SigningKey, + rng: &mut impl CryptoRngCore, + msg: &[u8], + ) -> Result { + use ecdsa::hazmat::{bits2field, sign_prehashed}; + use elliptic_curve::Field; + let prehash = Sha512::digest(msg); + let z = bits2field::(&prehash)?; + let k = p521::Scalar::random(rng); + sign_prehashed(key.as_nonzero_scalar().as_ref(), k, &z).map(|sig| sig.0) + } + + // Implement verifying because ecdsa::VerifyingKey does not satisfy the trait + // bound requirements of the DigestVerifier's implementation in ecdsa crate. + pub fn verify(key: &VerifyingKey, msg: &[u8], signature: &Signature) -> Result<()> { + use ecdsa::signature::hazmat::PrehashVerifier; + key.verify_prehash(&Sha512::digest(msg), signature) + } +} + +const CURVE_NISTP256: &str = "nistp256"; +const CURVE_NISTP384: &str = "nistp384"; +const CURVE_NISTP521: &str = "nistp521"; + +/// An ECC public key. +#[derive(Clone, Eq, PartialEq)] +pub enum PublicKey { + P256(p256::ecdsa::VerifyingKey), + P384(p384::ecdsa::VerifyingKey), + P521(local_p521::VerifyingKey), +} + +impl PublicKey { + /// Returns the elliptic curve domain parameter identifiers defined in RFC 5656 section 6.1. + pub fn ident(&self) -> &'static str { + match self { + Self::P256(_) => CURVE_NISTP256, + Self::P384(_) => CURVE_NISTP384, + Self::P521(_) => CURVE_NISTP521, + } + } + + /// Returns the ECC public key algorithm name defined in RFC 5656 section 6.2, in the form of + /// `"ecdsa-sha2-[identifier]"`. + pub fn algorithm(&self) -> &'static str { + match self { + Self::P256(_) => crate::ECDSA_SHA2_NISTP256, + Self::P384(_) => crate::ECDSA_SHA2_NISTP384, + Self::P521(_) => crate::ECDSA_SHA2_NISTP521, + } + } + + /// Creates a `PrivateKey` from algorithm name and SEC1-encoded point on curve. + pub fn from_sec1_bytes(algorithm: &[u8], bytes: &[u8]) -> Result { + match algorithm { + crate::KEYTYPE_ECDSA_SHA2_NISTP256 => Ok(Self::P256( + p256::ecdsa::VerifyingKey::from_sec1_bytes(bytes)?, + )), + crate::KEYTYPE_ECDSA_SHA2_NISTP384 => Ok(Self::P384( + p384::ecdsa::VerifyingKey::from_sec1_bytes(bytes)?, + )), + crate::KEYTYPE_ECDSA_SHA2_NISTP521 => Ok(Self::P521( + local_p521::VerifyingKey::from_sec1_bytes(bytes)?, + )), + _ => Err(Error::UnsupportedKeyType { + key_type_string: String::from_utf8(algorithm.to_vec()) + .unwrap_or_else(|_| format!("{algorithm:?}")), + key_type_raw: algorithm.to_vec(), + }), + } + } + + /// Returns the SEC1-encoded public curve point. + pub fn to_sec1_bytes(&self) -> Vec { + match self { + Self::P256(key) => key.to_encoded_point(false).as_bytes().to_vec(), + Self::P384(key) => key.to_encoded_point(false).as_bytes().to_vec(), + Self::P521(key) => key.to_encoded_point(false).as_bytes().to_vec(), + } + } + + /// Verifies message against signature `(r, s)` using the associated digest algorithm. + pub fn verify(&self, msg: &[u8], r: &[u8], s: &[u8]) -> Result<(), Error> { + use ecdsa::signature::Verifier; + match self { + Self::P256(key) => { + key.verify(msg, &signature_from_scalar_bytes::(r, s)?) + } + Self::P384(key) => { + key.verify(msg, &signature_from_scalar_bytes::(r, s)?) + } + Self::P521(key) => local_p521::verify( + key, + msg, + &signature_from_scalar_bytes::(r, s)?, + ), + } + .map_err(Error::from) + } +} + +impl std::fmt::Debug for PublicKey { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match *self { + Self::P256(_) => write!(f, "P256"), + Self::P384(_) => write!(f, "P384"), + Self::P521(_) => write!(f, "P521"), + } + } +} + +/// An ECC private key. +#[derive(Clone, Eq, PartialEq)] +pub enum PrivateKey { + P256(p256::ecdsa::SigningKey), + P384(p384::ecdsa::SigningKey), + P521(local_p521::SigningKey), +} + +impl PrivateKey { + /// Creates a `PrivateKey` with algorithm name and scalar. + pub fn new_from_secret_scalar(algorithm: &[u8], scalar: &[u8]) -> Result { + match algorithm { + crate::KEYTYPE_ECDSA_SHA2_NISTP256 => { + Ok(Self::P256(p256::ecdsa::SigningKey::from_slice(scalar)?)) + } + crate::KEYTYPE_ECDSA_SHA2_NISTP384 => { + Ok(Self::P384(p384::ecdsa::SigningKey::from_slice(scalar)?)) + } + crate::KEYTYPE_ECDSA_SHA2_NISTP521 => { + Ok(Self::P521(local_p521::SigningKey::from_slice(scalar)?)) + } + _ => Err(Error::UnsupportedKeyType { + key_type_string: String::from_utf8(algorithm.to_vec()) + .unwrap_or_else(|_| format!("{algorithm:?}")), + key_type_raw: algorithm.to_vec(), + }), + } + } + + /// Returns the elliptic curve domain parameter identifiers defined in RFC 5656 section 6.1. + pub fn ident(&self) -> &'static str { + match self { + Self::P256(_) => CURVE_NISTP256, + Self::P384(_) => CURVE_NISTP384, + Self::P521(_) => CURVE_NISTP521, + } + } + + /// Returns the ECC public key algorithm name defined in RFC 5656 section 6.2, in the form of + /// `"ecdsa-sha2-[identifier]"`. + pub fn algorithm(&self) -> &'static str { + match self { + Self::P256(_) => crate::ECDSA_SHA2_NISTP256, + Self::P384(_) => crate::ECDSA_SHA2_NISTP384, + Self::P521(_) => crate::ECDSA_SHA2_NISTP521, + } + } + + /// Returns the public key. + pub fn to_public_key(&self) -> PublicKey { + match self { + Self::P256(key) => PublicKey::P256(*key.verifying_key()), + Self::P384(key) => PublicKey::P384(*key.verifying_key()), + Self::P521(key) => PublicKey::P521(*key.verifying_key()), + } + } + + /// Returns the secret scalar in bytes. + pub fn to_secret_bytes(&self) -> Vec { + match self { + Self::P256(key) => key.to_bytes().to_vec(), + Self::P384(key) => key.to_bytes().to_vec(), + Self::P521(key) => key.to_bytes().to_vec(), + } + } + + /// Sign the message with associated digest algorithm. + pub fn try_sign(&self, msg: &[u8]) -> Result<(Vec, Vec), Error> { + use ecdsa::signature::RandomizedSigner; + Ok(match self { + Self::P256(key) => { + signature_to_scalar_bytes(key.try_sign_with_rng(&mut safe_rng(), msg)?) + } + Self::P384(key) => { + signature_to_scalar_bytes(key.try_sign_with_rng(&mut safe_rng(), msg)?) + } + Self::P521(key) => { + signature_to_scalar_bytes(local_p521::try_sign_with_rng(key, &mut safe_rng(), msg)?) + } + }) + } +} + +impl std::fmt::Debug for PrivateKey { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match *self { + Self::P256(_) => write!(f, "P256 {{ (hidden) }}"), + Self::P384(_) => write!(f, "P384 {{ (hidden) }}"), + Self::P521(_) => write!(f, "P521 {{ (hidden) }}"), + } + } +} + +fn try_field_bytes_from_mpint(b: &[u8]) -> Option> +where + C: Curve + CurveArithmetic, +{ + use typenum::Unsigned; + let size = FieldBytesSize::::to_usize(); + assert!(size > 0); + #[allow(clippy::indexing_slicing)] // Length checked + if b.len() == size + 1 && b[0] == 0 { + Some(FieldBytes::::clone_from_slice(&b[1..])) + } else if b.len() == size { + Some(FieldBytes::::clone_from_slice(b)) + } else if b.len() < size { + let mut fb: FieldBytes = Default::default(); + fb.as_mut_slice()[size - b.len()..].clone_from_slice(b); + Some(fb) + } else { + None + } +} + +fn signature_from_scalar_bytes(r: &[u8], s: &[u8]) -> Result, Error> +where + C: Curve + CurveArithmetic + elliptic_curve::PrimeCurve, + ecdsa::SignatureSize: elliptic_curve::generic_array::ArrayLength, +{ + Ok(ecdsa::Signature::::from_scalars( + try_field_bytes_from_mpint::(r).ok_or(Error::InvalidSignature)?, + try_field_bytes_from_mpint::(s).ok_or(Error::InvalidSignature)?, + )?) +} + +fn signature_to_scalar_bytes(sig: ecdsa::Signature) -> (Vec, Vec) +where + C: Curve + CurveArithmetic + elliptic_curve::PrimeCurve, + ecdsa::SignatureSize: elliptic_curve::generic_array::ArrayLength, +{ + let (r, s) = sig.split_bytes(); + (r.to_vec(), s.to_vec()) +} diff --git a/russh-keys/src/encoding.rs b/russh-keys/src/encoding.rs index 0f64f724..005196d3 100644 --- a/russh-keys/src/encoding.rs +++ b/russh-keys/src/encoding.rs @@ -41,6 +41,20 @@ pub trait Encoding { fn extend_list>(&mut self, list: I); /// Push an SSH-encoded empty list. fn write_empty_list(&mut self); + /// Push an SSH-encoded value. + fn extend_ssh(&mut self, v: &T) { + v.write_ssh(self) + } + /// Push a nested SSH-encoded value. + fn extend_wrapped(&mut self, write: F) + where + F: FnOnce(&mut Self); +} + +/// Trait for writing value in SSH-encoded format. +pub trait SshWrite { + /// Write the value. + fn write_ssh(&self, encoder: &mut E); } /// Encoding length of the given mpint. @@ -109,6 +123,20 @@ impl Encoding for Vec { fn write_empty_list(&mut self) { self.extend([0, 0, 0, 0]); } + + fn extend_wrapped(&mut self, write: F) + where + F: FnOnce(&mut Self), + { + let len_offset = self.len(); + #[allow(clippy::unwrap_used)] // writing into Vec<> can't panic + self.write_u32::(0).unwrap(); + let data_offset = self.len(); + write(self); + let data_len = self.len() - data_offset; + #[allow(clippy::indexing_slicing)] // length is known + BigEndian::write_u32(&mut self[len_offset..], data_len as u32); + } } impl Encoding for CryptoVec { @@ -163,6 +191,19 @@ impl Encoding for CryptoVec { fn write_empty_list(&mut self) { self.extend(&[0, 0, 0, 0]); } + + fn extend_wrapped(&mut self, write: F) + where + F: FnOnce(&mut Self), + { + let len_offset = self.len(); + self.push_u32_be(0); + let data_offset = self.len(); + write(self); + let data_len = self.len() - data_offset; + #[allow(clippy::indexing_slicing)] // length is known + BigEndian::write_u32(&mut self[len_offset..], data_len as u32); + } } /// A cursor-like trait to read SSH-encoded things. @@ -244,4 +285,30 @@ impl<'a> Position<'a> { Err(Error::IndexOutOfBounds) } } + + pub fn read_ssh>(&mut self) -> Result { + T::read_ssh(self) + } +} + +/// Trait for reading value in SSH-encoded format. +pub trait SshRead<'a>: Sized + 'a { + /// Read the value from a position. + fn read_ssh(pos: &mut Position<'a>) -> Result; +} + +impl<'a> ssh_encoding::Reader for Position<'a> { + fn read<'o>(&mut self, out: &'o mut [u8]) -> ssh_encoding::Result<&'o [u8]> { + out.copy_from_slice( + self.s + .get(self.position..(self.position + out.len())) + .ok_or(ssh_encoding::Error::Length)?, + ); + self.position += out.len(); + Ok(out) + } + + fn remaining_len(&self) -> usize { + self.s.len() - self.position + } } diff --git a/russh-keys/src/format/mod.rs b/russh-keys/src/format/mod.rs index 1463a85d..b120723c 100644 --- a/russh-keys/src/format/mod.rs +++ b/russh-keys/src/format/mod.rs @@ -1,21 +1,18 @@ use std::io::Write; -#[cfg(not(feature = "openssl"))] -use data_encoding::BASE64_MIME; -#[cfg(feature = "openssl")] use data_encoding::{BASE64_MIME, HEXLOWER_PERMISSIVE}; -#[cfg(feature = "openssl")] -use openssl::rsa::Rsa; use super::is_base64_char; use crate::{key, Error}; pub mod openssh; + +#[cfg(feature = "legacy-ed25519-pkcs8-parser")] +mod pkcs8_legacy; + pub use self::openssh::*; -#[cfg(feature = "openssl")] pub mod pkcs5; -#[cfg(feature = "openssl")] pub use self::pkcs5::*; pub mod pkcs8; @@ -33,10 +30,8 @@ pub enum Encryption { #[derive(Clone, Debug)] enum Format { - #[cfg(feature = "openssl")] Rsa, Openssh, - #[cfg(feature = "openssl")] Pkcs5Encrypted(Encryption), Pkcs8Encrypted, Pkcs8, @@ -57,35 +52,22 @@ pub fn decode_secret_key(secret: &str, password: Option<&str>) -> Result = HEXLOWER_PERMISSIVE - .decode(l.split_at(AES_128_CBC.len()).1.as_bytes())?; - if iv_.len() != 16 { - return Err(Error::CouldNotReadKey); - } - let mut iv = [0; 16]; - iv.clone_from_slice(&iv_); - format = Some(Format::Pkcs5Encrypted(Encryption::Aes128Cbc(iv))) + let iv_: Vec = + HEXLOWER_PERMISSIVE.decode(l.split_at(AES_128_CBC.len()).1.as_bytes())?; + if iv_.len() != 16 { + return Err(Error::CouldNotReadKey); } + let mut iv = [0; 16]; + iv.clone_from_slice(&iv_); + format = Some(Format::Pkcs5Encrypted(Encryption::Aes128Cbc(iv))) } } if l == "-----BEGIN OPENSSH PRIVATE KEY-----" { started = true; format = Some(Format::Openssh); } else if l == "-----BEGIN RSA PRIVATE KEY-----" { - #[cfg(not(feature = "openssl"))] - { - return Err(Error::UnsupportedKeyType { - key_type_string: "rsa".to_owned(), - key_type_raw: "rsa".as_bytes().to_vec(), - }); - } - #[cfg(feature = "openssl")] - { - started = true; - format = Some(Format::Rsa); - } + started = true; + format = Some(Format::Rsa); } else if l == "-----BEGIN ENCRYPTED PRIVATE KEY-----" { started = true; format = Some(Format::Pkcs8Encrypted); @@ -100,19 +82,28 @@ pub fn decode_secret_key(secret: &str, password: Option<&str>) -> Result decode_openssh(&secret, password), - #[cfg(feature = "openssl")] Some(Format::Rsa) => decode_rsa(&secret), - #[cfg(feature = "openssl")] Some(Format::Pkcs5Encrypted(enc)) => decode_pkcs5(&secret, password, enc), Some(Format::Pkcs8Encrypted) | Some(Format::Pkcs8) => { - self::pkcs8::decode_pkcs8(&secret, password.map(|x| x.as_bytes())) + let result = self::pkcs8::decode_pkcs8(&secret, password.map(|x| x.as_bytes())); + #[cfg(feature = "legacy-ed25519-pkcs8-parser")] + { + if result.is_err() { + let legacy_result = + pkcs8_legacy::decode_pkcs8(&secret, password.map(|x| x.as_bytes())); + if let Ok(key) = legacy_result { + return Ok(key); + } + } + } + result } None => Err(Error::CouldNotReadKey), } } pub fn encode_pkcs8_pem(key: &key::KeyPair, mut w: W) -> Result<(), Error> { - let x = self::pkcs8::encode_pkcs8(key); + let x = self::pkcs8::encode_pkcs8(key)?; w.write_all(b"-----BEGIN PRIVATE KEY-----\n")?; w.write_all(BASE64_MIME.encode(&x).as_bytes())?; w.write_all(b"\n-----END PRIVATE KEY-----\n")?; @@ -132,10 +123,9 @@ pub fn encode_pkcs8_pem_encrypted( Ok(()) } -#[cfg(feature = "openssl")] fn decode_rsa(secret: &[u8]) -> Result { Ok(key::KeyPair::RSA { - key: Rsa::private_key_from_der(secret)?, + key: crate::backend::RsaPrivate::new_from_der(secret)?, hash: key::SignatureHash::SHA2_256, }) } diff --git a/russh-keys/src/format/openssh.rs b/russh-keys/src/format/openssh.rs index 44821fb8..d0f2fdc6 100644 --- a/russh-keys/src/format/openssh.rs +++ b/russh-keys/src/format/openssh.rs @@ -1,169 +1,121 @@ use std::convert::TryFrom; -use aes::cipher::block_padding::NoPadding; -use aes::cipher::{BlockDecryptMut, KeyIvInit, StreamCipher}; -use bcrypt_pbkdf; -use ctr::Ctr64BE; -#[cfg(feature = "openssl")] -use openssl::bn::BigNum; +use ssh_key::private::{ + EcdsaKeypair, Ed25519Keypair, KeypairData, PrivateKey, RsaKeypair, RsaPrivateKey, +}; +use ssh_key::public::{Ed25519PublicKey, KeyData, RsaPublicKey}; +use ssh_key::{Algorithm, HashAlg}; -use crate::encoding::Reader; -use crate::{key, Error, KEYTYPE_ED25519, KEYTYPE_RSA}; +use crate::key::{KeyPair, PublicKey, SignatureHash}; +use crate::{ec, protocol, Error}; /// Decode a secret key given in the OpenSSH format, deciphering it if /// needed using the supplied password. -pub fn decode_openssh(secret: &[u8], password: Option<&str>) -> Result { - if matches!(secret.get(0..15), Some(b"openssh-key-v1\0")) { - let mut position = secret.reader(15); - - let ciphername = position.read_string()?; - let kdfname = position.read_string()?; - let kdfoptions = position.read_string()?; - - let nkeys = position.read_u32()?; +pub fn decode_openssh(secret: &[u8], password: Option<&str>) -> Result { + let pk = PrivateKey::from_bytes(secret)?; + KeyPair::try_from(&match password { + Some(password) => pk.decrypt(password)?, + None => pk, + }) +} - // Read all public keys - for _ in 0..nkeys { - position.read_string()?; - } +impl TryFrom<&PrivateKey> for KeyPair { + type Error = Error; - // Read all secret keys - let secret_ = position.read_string()?; - let secret = decrypt_secret_key(ciphername, kdfname, kdfoptions, password, secret_)?; - let mut position = secret.reader(0); - let _check0 = position.read_u32()?; - let _check1 = position.read_u32()?; - #[allow(clippy::never_loop)] - for _ in 0..nkeys { - // TODO check: never really loops beyond the first key - let key_type = position.read_string()?; - if key_type == KEYTYPE_ED25519 { - let pubkey = position.read_string()?; - let seckey = position.read_string()?; - let _comment = position.read_string()?; - if Some(pubkey) != seckey.get(32..) { + fn try_from(pk: &PrivateKey) -> Result { + match pk.key_data() { + KeypairData::Ed25519(Ed25519Keypair { public, private }) => { + let key = ed25519_dalek::SigningKey::from(private.as_ref()); + let public_key = ed25519_dalek::VerifyingKey::from_bytes(public.as_ref())?; + if public_key != key.verifying_key() { return Err(Error::KeyIsCorrupt); } - let secret = ed25519_dalek::SigningKey::try_from( - seckey.get(..32).ok_or(Error::KeyIsCorrupt)?, - )?; - return Ok(key::KeyPair::Ed25519(secret)); - } else if key_type == KEYTYPE_RSA && cfg!(feature = "openssl") { - #[cfg(feature = "openssl")] - { - let n = BigNum::from_slice(position.read_string()?)?; - let e = BigNum::from_slice(position.read_string()?)?; - let d = BigNum::from_slice(position.read_string()?)?; - let iqmp = BigNum::from_slice(position.read_string()?)?; - let p = BigNum::from_slice(position.read_string()?)?; - let q = BigNum::from_slice(position.read_string()?)?; - - let mut ctx = openssl::bn::BigNumContext::new()?; - let un = openssl::bn::BigNum::from_u32(1)?; - let mut p1 = openssl::bn::BigNum::new()?; - let mut q1 = openssl::bn::BigNum::new()?; - p1.checked_sub(&p, &un)?; - q1.checked_sub(&q, &un)?; - let mut dmp1 = openssl::bn::BigNum::new()?; // d mod p-1 - dmp1.checked_rem(&d, &p1, &mut ctx)?; - let mut dmq1 = openssl::bn::BigNum::new()?; // d mod q-1 - dmq1.checked_rem(&d, &q1, &mut ctx)?; - - let key = openssl::rsa::RsaPrivateKeyBuilder::new(n, e, d)? - .set_factors(p, q)? - .set_crt_params(dmp1, dmq1, iqmp)? - .build(); - key.check_key()?; - return Ok(key::KeyPair::RSA { - key, - hash: key::SignatureHash::SHA2_512, - }); + Ok(KeyPair::Ed25519(key)) + } + KeypairData::Rsa(keypair) => { + KeyPair::new_rsa_with_hash(&keypair.into(), None, SignatureHash::SHA2_512) + } + KeypairData::Ecdsa(keypair) => { + let key_type = match keypair { + EcdsaKeypair::NistP256 { .. } => crate::KEYTYPE_ECDSA_SHA2_NISTP256, + EcdsaKeypair::NistP384 { .. } => crate::KEYTYPE_ECDSA_SHA2_NISTP384, + EcdsaKeypair::NistP521 { .. } => crate::KEYTYPE_ECDSA_SHA2_NISTP521, + }; + let key = + ec::PrivateKey::new_from_secret_scalar(key_type, keypair.private_key_bytes())?; + let public_key = + ec::PublicKey::from_sec1_bytes(key_type, keypair.public_key_bytes())?; + if public_key != key.to_public_key() { + return Err(Error::KeyIsCorrupt); } - } else { - return Err(Error::UnsupportedKeyType { - key_type_string: String::from_utf8(key_type.to_vec()) - .unwrap_or_else(|_| format!("{key_type:?}")), - key_type_raw: key_type.to_vec(), - }); + Ok(KeyPair::EC { key }) } + KeypairData::Encrypted(_) => Err(Error::KeyIsEncrypted), + _ => Err(Error::UnsupportedKeyType { + key_type_string: pk.algorithm().as_str().into(), + key_type_raw: pk.algorithm().as_str().as_bytes().into(), + }), + } + } +} + +impl<'a> From<&'a RsaKeypair> for protocol::RsaPrivateKey<'a> { + fn from(key: &'a RsaKeypair) -> Self { + let RsaPublicKey { e, n } = &key.public; + let RsaPrivateKey { d, iqmp, p, q } = &key.private; + Self { + public_key: protocol::RsaPublicKey { + public_exponent: e.as_bytes().into(), + modulus: n.as_bytes().into(), + }, + private_exponent: d.as_bytes().into(), + prime1: p.as_bytes().into(), + prime2: q.as_bytes().into(), + coefficient: iqmp.as_bytes().into(), + comment: b"".as_slice().into(), } - Err(Error::CouldNotReadKey) - } else { - Err(Error::CouldNotReadKey) } } -use aes::*; +impl TryFrom<&KeyData> for PublicKey { + type Error = Error; -fn decrypt_secret_key( - ciphername: &[u8], - kdfname: &[u8], - kdfoptions: &[u8], - password: Option<&str>, - secret_key: &[u8], -) -> Result, Error> { - if kdfname == b"none" { - if password.is_none() { - Ok(secret_key.to_vec()) - } else { - Err(Error::CouldNotReadKey) + fn try_from(key_data: &KeyData) -> Result { + match key_data { + KeyData::Ed25519(Ed25519PublicKey(public)) => Ok(PublicKey::Ed25519( + ed25519_dalek::VerifyingKey::from_bytes(public)?, + )), + KeyData::Rsa(ref public) => PublicKey::new_rsa_with_hash( + &public.into(), + match key_data.algorithm() { + Algorithm::Rsa { hash } => match hash { + Some(HashAlg::Sha256) => SignatureHash::SHA2_256, + Some(HashAlg::Sha512) => SignatureHash::SHA2_512, + _ => SignatureHash::SHA1, + }, + _ => return Err(Error::KeyIsCorrupt), + }, + ), + KeyData::Ecdsa(public) => Ok(PublicKey::EC { + key: ec::PublicKey::from_sec1_bytes( + key_data.algorithm().as_str().as_bytes(), + public.as_sec1_bytes(), + )?, + }), + _ => Err(Error::UnsupportedKeyType { + key_type_string: key_data.algorithm().as_str().into(), + key_type_raw: key_data.algorithm().as_str().as_bytes().into(), + }), } - } else if let Some(password) = password { - let mut key = [0; 48]; - let n = match ciphername { - b"aes128-cbc" | b"aes128-ctr" => 32, - b"aes256-cbc" | b"aes256-ctr" => 48, - _ => return Err(Error::CouldNotReadKey), - }; - match kdfname { - b"bcrypt" => { - let mut kdfopts = kdfoptions.reader(0); - let salt = kdfopts.read_string()?; - let rounds = kdfopts.read_u32()?; - #[allow(clippy::unwrap_used)] // parameters are static - #[allow(clippy::indexing_slicing)] // output length is static - match bcrypt_pbkdf::bcrypt_pbkdf(password, salt, rounds, &mut key[..n]) { - Err(bcrypt_pbkdf::Error::InvalidParamLen) => return Err(Error::KeyIsEncrypted), - e => e.unwrap(), - } - } - _kdfname => { - return Err(Error::CouldNotReadKey); - } - }; - let (key, iv) = key.split_at(n - 16); + } +} - let mut dec = secret_key.to_vec(); - dec.resize(dec.len() + 32, 0u8); - match ciphername { - b"aes128-cbc" => { - #[allow(clippy::unwrap_used)] // parameters are static - let cipher = cbc::Decryptor::::new_from_slices(key, iv).unwrap(); - let n = cipher.decrypt_padded_mut::(&mut dec)?.len(); - dec.truncate(n) - } - b"aes256-cbc" => { - #[allow(clippy::unwrap_used)] // parameters are static - let cipher = cbc::Decryptor::::new_from_slices(key, iv).unwrap(); - let n = cipher.decrypt_padded_mut::(&mut dec)?.len(); - dec.truncate(n) - } - b"aes128-ctr" => { - #[allow(clippy::unwrap_used)] // parameters are static - let mut cipher = Ctr64BE::::new_from_slices(key, iv).unwrap(); - cipher.apply_keystream(&mut dec); - dec.truncate(secret_key.len()) - } - b"aes256-ctr" => { - #[allow(clippy::unwrap_used)] // parameters are static - let mut cipher = Ctr64BE::::new_from_slices(key, iv).unwrap(); - cipher.apply_keystream(&mut dec); - dec.truncate(secret_key.len()) - } - _ => {} +impl<'a> From<&'a RsaPublicKey> for protocol::RsaPublicKey<'a> { + fn from(key: &'a RsaPublicKey) -> Self { + let RsaPublicKey { e, n } = key; + Self { + public_exponent: e.as_bytes().into(), + modulus: n.as_bytes().into(), } - Ok(dec) - } else { - Err(Error::KeyIsEncrypted) } } diff --git a/russh-keys/src/format/pkcs5.rs b/russh-keys/src/format/pkcs5.rs index 0e5a2a5e..b1b4c266 100644 --- a/russh-keys/src/format/pkcs5.rs +++ b/russh-keys/src/format/pkcs5.rs @@ -3,9 +3,8 @@ use aes::*; use super::Encryption; use crate::{key, Error}; -/// Decode a secret key in the PKCS#5 format, possible deciphering it +/// Decode a secret key in the PKCS#5 format, possibly deciphering it /// using the supplied password. -#[cfg(feature = "openssl")] pub fn decode_pkcs5( secret: &[u8], password: Option<&str>, @@ -25,8 +24,7 @@ pub fn decode_pkcs5( #[allow(clippy::unwrap_used)] // AES parameters are static let c = cbc::Decryptor::::new_from_slices(&md5.0, &iv[..]).unwrap(); let mut dec = secret.to_vec(); - c.decrypt_padded_mut::(&mut dec)?; - dec + c.decrypt_padded_mut::(&mut dec)?.to_vec() } Encryption::Aes256Cbc(_) => unimplemented!(), }; diff --git a/russh-keys/src/format/pkcs8.rs b/russh-keys/src/format/pkcs8.rs index a9cbb96d..ceb8a313 100644 --- a/russh-keys/src/format/pkcs8.rs +++ b/russh-keys/src/format/pkcs8.rs @@ -1,259 +1,134 @@ -use std::borrow::Cow; +use std::convert::{TryFrom, TryInto}; -use aes::cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit}; -use bit_vec::BitVec; -use block_padding::{NoPadding, Pkcs7}; -#[cfg(feature = "openssl")] -use openssl::pkey::Private; -#[cfg(feature = "openssl")] -use openssl::rsa::Rsa; -#[cfg(test)] -use rand_core::OsRng; -use std::convert::TryFrom; -use yasna::BERReaderSeq; -use {std, yasna}; +use pkcs8::{EncodePrivateKey, PrivateKeyInfo, SecretDocument}; -use super::Encryption; -#[cfg(feature = "openssl")] use crate::key::SignatureHash; -use crate::{key, Error}; - -const PBES2: &[u64] = &[1, 2, 840, 113549, 1, 5, 13]; -const PBKDF2: &[u64] = &[1, 2, 840, 113549, 1, 5, 12]; -const HMAC_SHA256: &[u64] = &[1, 2, 840, 113549, 2, 9]; -const AES256CBC: &[u64] = &[2, 16, 840, 1, 101, 3, 4, 1, 42]; -const ED25519: &[u64] = &[1, 3, 101, 112]; -#[cfg(feature = "openssl")] -const RSA: &[u64] = &[1, 2, 840, 113549, 1, 1, 1]; +use crate::{ec, key, protocol, Error}; /// Decode a PKCS#8-encoded private key. pub fn decode_pkcs8(ciphertext: &[u8], password: Option<&[u8]>) -> Result { - let secret = if let Some(pass) = password { - Cow::Owned(yasna::parse_der(ciphertext, |reader| { - reader.read_sequence(|reader| { - // Encryption parameters - let parameters = reader.next().read_sequence(|reader| { - let oid = reader.next().read_oid()?; - if oid.components().as_slice() == PBES2 { - asn1_read_pbes2(reader) - } else { - Ok(Err(Error::UnknownAlgorithm(oid))) - } - })?; - // Ciphertext - let ciphertext = reader.next().read_bytes()?; - Ok(parameters.map(|p| p.decrypt(pass, &ciphertext))) - }) - })???) + let doc = SecretDocument::try_from(ciphertext)?; + let doc = if let Some(password) = password { + doc.decode_msg::()? + .decrypt(password)? } else { - Cow::Borrowed(ciphertext) + doc }; - yasna::parse_der(&secret, |reader| { - reader.read_sequence(|reader| { - let version = reader.next().read_u64()?; - if version == 0 { - Ok(read_key_v0(reader)) - } else if version == 1 { - Ok(read_key_v1(reader)) - } else { - Ok(Err(Error::CouldNotReadKey)) + key::KeyPair::try_from(doc.decode_msg::()?) +} + +impl<'a> TryFrom> for key::KeyPair { + type Error = Error; + + fn try_from(pki: PrivateKeyInfo<'a>) -> Result { + match pki.algorithm.oid { + ed25519_dalek::pkcs8::ALGORITHM_OID => Ok(key::KeyPair::Ed25519( + ed25519_dalek::pkcs8::KeypairBytes::try_from(pki)? + .secret_key + .into(), + )), + pkcs1::ALGORITHM_OID => { + let sk = &pkcs1::RsaPrivateKey::try_from(pki.private_key)?; + key::KeyPair::new_rsa_with_hash( + &sk.into(), + Some(&sk.into()), + SignatureHash::SHA2_256, + ) } - }) - })? + sec1::ALGORITHM_OID => Ok(key::KeyPair::EC { + key: pki.try_into()?, + }), + oid => Err(Error::UnknownAlgorithm(oid)), + } + } } -fn asn1_read_pbes2( - reader: &mut yasna::BERReaderSeq, -) -> Result, yasna::ASN1Error> { - reader.next().read_sequence(|reader| { - // PBES2 has two components. - // 1. Key generation algorithm - let keygen = reader.next().read_sequence(|reader| { - let oid = reader.next().read_oid()?; - if oid.components().as_slice() == PBKDF2 { - asn1_read_pbkdf2(reader) - } else { - Ok(Err(Error::UnknownAlgorithm(oid))) - } - })?; - // 2. Encryption algorithm. - let algorithm = reader.next().read_sequence(|reader| { - let oid = reader.next().read_oid()?; - if oid.components().as_slice() == AES256CBC { - asn1_read_aes256cbc(reader) - } else { - Ok(Err(Error::UnknownAlgorithm(oid))) - } - })?; - Ok(keygen.and_then(|keygen| algorithm.map(|algo| Algorithms::Pbes2(keygen, algo)))) - }) +impl<'a> From<&pkcs1::RsaPrivateKey<'a>> for protocol::RsaPrivateKey<'a> { + fn from(sk: &pkcs1::RsaPrivateKey<'a>) -> Self { + Self { + public_key: protocol::RsaPublicKey { + public_exponent: sk.public_exponent.as_bytes().into(), + modulus: sk.modulus.as_bytes().into(), + }, + private_exponent: sk.private_exponent.as_bytes().into(), + prime1: sk.prime1.as_bytes().into(), + prime2: sk.prime2.as_bytes().into(), + coefficient: sk.coefficient.as_bytes().into(), + comment: b"".as_slice().into(), + } + } } -fn asn1_read_pbkdf2( - reader: &mut yasna::BERReaderSeq, -) -> Result, yasna::ASN1Error> { - reader.next().read_sequence(|reader| { - let salt = reader.next().read_bytes()?; - let rounds = reader.next().read_u64()?; - let digest = reader.next().read_sequence(|reader| { - let oid = reader.next().read_oid()?; - if oid.components().as_slice() == HMAC_SHA256 { - reader.next().read_null()?; - Ok(Ok(())) - } else { - Ok(Err(Error::UnknownAlgorithm(oid))) - } - })?; - Ok(digest.map(|()| KeyDerivation::Pbkdf2 { salt, rounds })) - }) +impl<'a> From<&pkcs1::RsaPrivateKey<'a>> for key::RsaCrtExtra<'a> { + fn from(sk: &pkcs1::RsaPrivateKey<'a>) -> Self { + Self { + dp: sk.exponent1.as_bytes().into(), + dq: sk.exponent2.as_bytes().into(), + } + } } -fn asn1_read_aes256cbc( - reader: &mut yasna::BERReaderSeq, -) -> Result, yasna::ASN1Error> { - let iv = reader.next().read_bytes()?; - let mut i = [0; 16]; - i.clone_from_slice(&iv); - Ok(Ok(Encryption::Aes256Cbc(i))) -} +// Note: It's infeasible to implement `EncodePrivateKey` because that is bound to `pkcs8::Result`. +impl TryFrom<&key::RsaPrivate> for SecretDocument { + type Error = Error; -fn write_key_v1(writer: &mut yasna::DERWriterSeq, secret: &ed25519_dalek::SigningKey) { - let public = ed25519_dalek::VerifyingKey::from(secret); - writer.next().write_u32(1); - // write OID - writer.next().write_sequence(|writer| { - writer - .next() - .write_oid(&ObjectIdentifier::from_slice(ED25519)); - }); - let seed = yasna::construct_der(|writer| { - writer.write_bytes( - [secret.to_bytes().as_slice(), public.as_bytes().as_slice()] - .concat() - .as_slice(), - ) - }); - writer.next().write_bytes(&seed); - writer - .next() - .write_tagged(yasna::Tag::context(1), |writer| { - writer.write_bitvec(&BitVec::from_bytes(public.as_bytes())) - }) -} + fn try_from(key: &key::RsaPrivate) -> Result { + use der::Encode; + use pkcs1::UintRef; -fn read_key_v1(reader: &mut BERReaderSeq) -> Result { - let oid = reader - .next() - .read_sequence(|reader| reader.next().read_oid())?; - if oid.components().as_slice() == ED25519 { - use ed25519_dalek::SigningKey; - let secret = { - let s = yasna::parse_der(&reader.next().read_bytes()?, |reader| reader.read_bytes())?; + let sk = protocol::RsaPrivateKey::try_from(key)?; + let extra = key::RsaCrtExtra::try_from(key)?; - s.get(..ed25519_dalek::SECRET_KEY_LENGTH) - .ok_or(Error::KeyIsCorrupt) - .and_then(|s| SigningKey::try_from(s).map_err(|_| Error::CouldNotReadKey))? + let rsa_private_key = pkcs1::RsaPrivateKey { + modulus: UintRef::new(&sk.public_key.modulus)?, + public_exponent: UintRef::new(&sk.public_key.public_exponent)?, + private_exponent: UintRef::new(&sk.private_exponent)?, + prime1: UintRef::new(&sk.prime1)?, + prime2: UintRef::new(&sk.prime2)?, + exponent1: UintRef::new(&extra.dp)?, + exponent2: UintRef::new(&extra.dq)?, + coefficient: UintRef::new(&sk.coefficient)?, + other_prime_infos: None, }; - // Consume the public key - reader - .next() - .read_tagged(yasna::Tag::context(1), |reader| reader.read_bitvec())?; - Ok(key::KeyPair::Ed25519(secret)) - } else { - Err(Error::CouldNotReadKey) + let pki = PrivateKeyInfo { + algorithm: spki::AlgorithmIdentifier { + oid: pkcs1::ALGORITHM_OID, + parameters: Some(der::asn1::Null.into()), + }, + private_key: &rsa_private_key.to_der()?, + public_key: None, + }; + Ok(Self::try_from(pki)?) } } -#[cfg(feature = "openssl")] -fn write_key_v0(writer: &mut yasna::DERWriterSeq, key: &Rsa) { - writer.next().write_u32(0); - // write OID - writer.next().write_sequence(|writer| { - writer.next().write_oid(&ObjectIdentifier::from_slice(RSA)); - writer.next().write_null() - }); - let bytes = yasna::construct_der(|writer| { - #[allow(clippy::unwrap_used)] // key is known to be private - writer.write_sequence(|writer| { - writer.next().write_u32(0); - use num_bigint::BigUint; - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.n().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.e().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.d().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.p().unwrap().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.q().unwrap().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.dmp1().unwrap().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.dmq1().unwrap().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.iqmp().unwrap().to_vec())); - }) - }); - writer.next().write_bytes(&bytes); -} +impl TryFrom> for ec::PrivateKey { + type Error = Error; -#[cfg(feature = "openssl")] -fn read_key_v0(reader: &mut BERReaderSeq) -> Result { - let oid = reader.next().read_sequence(|reader| { - let oid = reader.next().read_oid()?; - reader.next().read_null()?; - Ok(oid) - })?; - if oid.components().as_slice() == RSA { - let seq = &reader.next().read_bytes()?; - let rsa: Result, Error> = yasna::parse_der(seq, |reader| { - reader.read_sequence(|reader| { - let version = reader.next().read_u32()?; - if version != 0 { - return Ok(Err(Error::CouldNotReadKey)); - } - use openssl::bn::BigNum; - let mut read_key = || -> Result, Error> { - Ok(Rsa::from_private_components( - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - )?) - }; - Ok(read_key()) - }) - })?; - Ok(key::KeyPair::RSA { - key: rsa?, - hash: SignatureHash::SHA2_256, - }) - } else { - Err(Error::CouldNotReadKey) + fn try_from(pki: PrivateKeyInfo<'_>) -> Result { + use pkcs8::AssociatedOid; + match pki.algorithm.parameters_oid()? { + p256::NistP256::OID => Ok(ec::PrivateKey::P256(pki.try_into()?)), + p384::NistP384::OID => Ok(ec::PrivateKey::P384(pki.try_into()?)), + p521::NistP521::OID => Ok(ec::PrivateKey::P521(pki.try_into()?)), + oid => Err(Error::UnknownAlgorithm(oid)), + } } } -#[cfg(not(feature = "openssl"))] -fn read_key_v0(_: &mut BERReaderSeq) -> Result { - Err(Error::CouldNotReadKey) +impl EncodePrivateKey for ec::PrivateKey { + fn to_pkcs8_der(&self) -> pkcs8::Result { + match self { + ec::PrivateKey::P256(key) => key.to_pkcs8_der(), + ec::PrivateKey::P384(key) => key.to_pkcs8_der(), + ec::PrivateKey::P521(key) => key.to_pkcs8_der(), + } + } } #[test] fn test_read_write_pkcs8() { - let secret = ed25519_dalek::SigningKey::generate(&mut OsRng {}); + let secret = ed25519_dalek::SigningKey::generate(&mut key::safe_rng()); assert_eq!( secret.verifying_key().as_bytes(), ed25519_dalek::VerifyingKey::from(&secret).as_bytes() @@ -264,176 +139,43 @@ fn test_read_write_pkcs8() { let key = decode_pkcs8(&ciphertext, Some(password)).unwrap(); match key { key::KeyPair::Ed25519 { .. } => println!("Ed25519"), - #[cfg(feature = "openssl")] + key::KeyPair::EC { .. } => println!("EC"), key::KeyPair::RSA { .. } => println!("RSA"), } } -use aes::*; -use yasna::models::ObjectIdentifier; - /// Encode a password-protected PKCS#8-encoded private key. pub fn encode_pkcs8_encrypted( pass: &[u8], rounds: u32, key: &key::KeyPair, ) -> Result, Error> { + let pvi_bytes = encode_pkcs8(key)?; + let pvi = PrivateKeyInfo::try_from(pvi_bytes.as_slice())?; + use rand::RngCore; let mut rng = rand::thread_rng(); let mut salt = [0; 64]; rng.fill_bytes(&mut salt); let mut iv = [0; 16]; rng.fill_bytes(&mut iv); - let mut dkey = [0; 32]; // AES256-CBC - pbkdf2::pbkdf2::>(pass, &salt, rounds, &mut dkey); - let mut plaintext = encode_pkcs8(key); - - let padding_len = 32 - (plaintext.len() % 32); - plaintext.extend(std::iter::repeat(padding_len as u8).take(padding_len)); - - #[allow(clippy::unwrap_used)] // parameters are static - let c = cbc::Encryptor::::new_from_slices(&dkey, &iv).unwrap(); - let n = plaintext.len(); - let encrypted = c.encrypt_padded_mut::(&mut plaintext, n)?; - Ok(yasna::construct_der(|writer| { - writer.write_sequence(|writer| { - // Encryption parameters - writer.next().write_sequence(|writer| { - writer - .next() - .write_oid(&ObjectIdentifier::from_slice(PBES2)); - asn1_write_pbes2(writer.next(), rounds as u64, &salt, &iv) - }); - // Ciphertext - writer.next().write_bytes(encrypted) - }) - })) + let doc = pvi.encrypt_with_params( + pkcs5::pbes2::Parameters::pbkdf2_sha256_aes256cbc(rounds, &salt, &iv) + .map_err(|_| Error::InvalidParameters)?, + pass, + )?; + Ok(doc.as_bytes().to_vec()) } /// Encode a Decode a PKCS#8-encoded private key. -pub fn encode_pkcs8(key: &key::KeyPair) -> Vec { - yasna::construct_der(|writer| { - writer.write_sequence(|writer| match *key { - key::KeyPair::Ed25519(ref pair) => write_key_v1(writer, pair), - #[cfg(feature = "openssl")] - key::KeyPair::RSA { ref key, .. } => write_key_v0(writer, key), - }) - }) -} - -fn asn1_write_pbes2(writer: yasna::DERWriter, rounds: u64, salt: &[u8], iv: &[u8]) { - writer.write_sequence(|writer| { - // 1. Key generation algorithm - writer.next().write_sequence(|writer| { - writer - .next() - .write_oid(&ObjectIdentifier::from_slice(PBKDF2)); - asn1_write_pbkdf2(writer.next(), rounds, salt) - }); - // 2. Encryption algorithm. - writer.next().write_sequence(|writer| { - writer - .next() - .write_oid(&ObjectIdentifier::from_slice(AES256CBC)); - writer.next().write_bytes(iv) - }); - }) -} - -fn asn1_write_pbkdf2(writer: yasna::DERWriter, rounds: u64, salt: &[u8]) { - writer.write_sequence(|writer| { - writer.next().write_bytes(salt); - writer.next().write_u64(rounds); - writer.next().write_sequence(|writer| { - writer - .next() - .write_oid(&ObjectIdentifier::from_slice(HMAC_SHA256)); - writer.next().write_null() - }) - }) -} - -enum Algorithms { - Pbes2(KeyDerivation, Encryption), -} - -impl Algorithms { - fn decrypt(&self, password: &[u8], cipher: &[u8]) -> Result, Error> { - match *self { - Algorithms::Pbes2(ref der, ref enc) => { - let mut key = enc.key(); - der.derive(password, &mut key)?; - let out = enc.decrypt(&key, cipher)?; - Ok(out) - } - } - } -} - -impl KeyDerivation { - fn derive(&self, password: &[u8], key: &mut [u8]) -> Result<(), Error> { - match *self { - KeyDerivation::Pbkdf2 { ref salt, rounds } => { - pbkdf2::pbkdf2::>(password, salt, rounds as u32, key) - // pbkdf2_hmac(password, salt, rounds as usize, digest, key)? - } - } - Ok(()) +pub fn encode_pkcs8(key: &key::KeyPair) -> Result, Error> { + let v = match *key { + key::KeyPair::Ed25519(ref pair) => pair.to_pkcs8_der()?, + key::KeyPair::RSA { ref key, .. } => SecretDocument::try_from(key)?, + key::KeyPair::EC { ref key, .. } => key.to_pkcs8_der()?, } -} - -#[derive(Debug)] -enum Key { - K128([u8; 16]), - K256([u8; 32]), -} - -impl std::ops::Deref for Key { - type Target = [u8]; - fn deref(&self) -> &[u8] { - match *self { - Key::K128(ref k) => k, - Key::K256(ref k) => k, - } - } -} - -impl std::ops::DerefMut for Key { - fn deref_mut(&mut self) -> &mut [u8] { - match *self { - Key::K128(ref mut k) => k, - Key::K256(ref mut k) => k, - } - } -} - -impl Encryption { - fn key(&self) -> Key { - match *self { - Encryption::Aes128Cbc(_) => Key::K128([0; 16]), - Encryption::Aes256Cbc(_) => Key::K256([0; 32]), - } - } - - fn decrypt(&self, key: &[u8], ciphertext: &[u8]) -> Result, Error> { - match *self { - Encryption::Aes128Cbc(ref iv) => { - #[allow(clippy::unwrap_used)] // parameters are static - let c = cbc::Decryptor::::new_from_slices(key, iv).unwrap(); - let mut dec = ciphertext.to_vec(); - Ok(c.decrypt_padded_mut::(&mut dec)?.into()) - } - Encryption::Aes256Cbc(ref iv) => { - #[allow(clippy::unwrap_used)] // parameters are static - let c = cbc::Decryptor::::new_from_slices(key, iv).unwrap(); - let mut dec = ciphertext.to_vec(); - Ok(c.decrypt_padded_mut::(&mut dec)?.into()) - } - } - } -} - -enum KeyDerivation { - Pbkdf2 { salt: Vec, rounds: u64 }, + .as_bytes() + .to_vec(); + Ok(v) } diff --git a/russh-keys/src/format/pkcs8_legacy.rs b/russh-keys/src/format/pkcs8_legacy.rs new file mode 100644 index 00000000..5553ccee --- /dev/null +++ b/russh-keys/src/format/pkcs8_legacy.rs @@ -0,0 +1,212 @@ +use std::borrow::Cow; +use std::convert::TryFrom; + +use aes::cipher::{BlockDecryptMut, KeyIvInit}; +use aes::*; +use block_padding::Pkcs7; +use yasna::BERReaderSeq; + +use super::Encryption; +use crate::{key, Error}; + +const PBES2: &[u64] = &[1, 2, 840, 113549, 1, 5, 13]; +const ED25519: &[u64] = &[1, 3, 101, 112]; +const PBKDF2: &[u64] = &[1, 2, 840, 113549, 1, 5, 12]; +const AES256CBC: &[u64] = &[2, 16, 840, 1, 101, 3, 4, 1, 42]; +const HMAC_SHA256: &[u64] = &[1, 2, 840, 113549, 2, 9]; + +pub fn decode_pkcs8(ciphertext: &[u8], password: Option<&[u8]>) -> Result { + let secret = if let Some(pass) = password { + Cow::Owned(yasna::parse_der(ciphertext, |reader| { + reader.read_sequence(|reader| { + // Encryption parameters + let parameters = reader.next().read_sequence(|reader| { + let oid = reader.next().read_oid()?; + if oid.components().as_slice() == PBES2 { + asn1_read_pbes2(reader) + } else { + Ok(Err(Error::InvalidParameters)) + } + })?; + // Ciphertext + let ciphertext = reader.next().read_bytes()?; + Ok(parameters.map(|p| p.decrypt(pass, &ciphertext))) + }) + })???) + } else { + Cow::Borrowed(ciphertext) + }; + yasna::parse_der(&secret, |reader| { + reader.read_sequence(|reader| { + let version = reader.next().read_u64()?; + if version == 0 { + Ok(Err(Error::CouldNotReadKey)) + } else if version == 1 { + Ok(read_key_v1(reader)) + } else { + Ok(Err(Error::CouldNotReadKey)) + } + }) + })? +} + +fn read_key_v1(reader: &mut BERReaderSeq) -> Result { + let oid = reader + .next() + .read_sequence(|reader| reader.next().read_oid())?; + if oid.components().as_slice() == ED25519 { + use ed25519_dalek::SigningKey; + let secret = { + let s = yasna::parse_der(&reader.next().read_bytes()?, |reader| reader.read_bytes())?; + + s.get(..ed25519_dalek::SECRET_KEY_LENGTH) + .ok_or(Error::KeyIsCorrupt) + .and_then(|s| SigningKey::try_from(s).map_err(|_| Error::CouldNotReadKey))? + }; + // Consume the public key + reader + .next() + .read_tagged(yasna::Tag::context(1), |reader| reader.read_bitvec())?; + Ok(key::KeyPair::Ed25519(secret)) + } else { + Err(Error::CouldNotReadKey) + } +} + +#[derive(Debug)] +enum Key { + K128([u8; 16]), + K256([u8; 32]), +} + +impl std::ops::Deref for Key { + type Target = [u8]; + fn deref(&self) -> &[u8] { + match *self { + Key::K128(ref k) => k, + Key::K256(ref k) => k, + } + } +} + +impl std::ops::DerefMut for Key { + fn deref_mut(&mut self) -> &mut [u8] { + match *self { + Key::K128(ref mut k) => k, + Key::K256(ref mut k) => k, + } + } +} + +enum Algorithms { + Pbes2(KeyDerivation, Encryption), +} + +impl Algorithms { + fn decrypt(&self, password: &[u8], cipher: &[u8]) -> Result, Error> { + match *self { + Algorithms::Pbes2(ref der, ref enc) => { + let mut key = enc.key(); + der.derive(password, &mut key)?; + let out = enc.decrypt(&key, cipher)?; + Ok(out) + } + } + } +} + +impl Encryption { + fn key(&self) -> Key { + match *self { + Encryption::Aes128Cbc(_) => Key::K128([0; 16]), + Encryption::Aes256Cbc(_) => Key::K256([0; 32]), + } + } + + fn decrypt(&self, key: &[u8], ciphertext: &[u8]) -> Result, Error> { + match *self { + Encryption::Aes128Cbc(ref iv) => { + #[allow(clippy::unwrap_used)] // parameters are static + let c = cbc::Decryptor::::new_from_slices(key, iv).unwrap(); + let mut dec = ciphertext.to_vec(); + Ok(c.decrypt_padded_mut::(&mut dec)?.into()) + } + Encryption::Aes256Cbc(ref iv) => { + #[allow(clippy::unwrap_used)] // parameters are static + let c = cbc::Decryptor::::new_from_slices(key, iv).unwrap(); + let mut dec = ciphertext.to_vec(); + Ok(c.decrypt_padded_mut::(&mut dec)?.into()) + } + } + } +} + +enum KeyDerivation { + Pbkdf2 { salt: Vec, rounds: u64 }, +} + +impl KeyDerivation { + fn derive(&self, password: &[u8], key: &mut [u8]) -> Result<(), Error> { + match *self { + KeyDerivation::Pbkdf2 { ref salt, rounds } => { + pbkdf2::pbkdf2::>(password, salt, rounds as u32, key) + // pbkdf2_hmac(password, salt, rounds as usize, digest, key)? + } + } + Ok(()) + } +} +fn asn1_read_pbes2( + reader: &mut yasna::BERReaderSeq, +) -> Result, yasna::ASN1Error> { + reader.next().read_sequence(|reader| { + // PBES2 has two components. + // 1. Key generation algorithm + let keygen = reader.next().read_sequence(|reader| { + let oid = reader.next().read_oid()?; + if oid.components().as_slice() == PBKDF2 { + asn1_read_pbkdf2(reader) + } else { + Ok(Err(Error::InvalidParameters)) + } + })?; + // 2. Encryption algorithm. + let algorithm = reader.next().read_sequence(|reader| { + let oid = reader.next().read_oid()?; + if oid.components().as_slice() == AES256CBC { + asn1_read_aes256cbc(reader) + } else { + Ok(Err(Error::InvalidParameters)) + } + })?; + Ok(keygen.and_then(|keygen| algorithm.map(|algo| Algorithms::Pbes2(keygen, algo)))) + }) +} + +fn asn1_read_pbkdf2( + reader: &mut yasna::BERReaderSeq, +) -> Result, yasna::ASN1Error> { + reader.next().read_sequence(|reader| { + let salt = reader.next().read_bytes()?; + let rounds = reader.next().read_u64()?; + let digest = reader.next().read_sequence(|reader| { + let oid = reader.next().read_oid()?; + if oid.components().as_slice() == HMAC_SHA256 { + reader.next().read_null()?; + Ok(Ok(())) + } else { + Ok(Err(Error::InvalidParameters)) + } + })?; + Ok(digest.map(|()| KeyDerivation::Pbkdf2 { salt, rounds })) + }) +} + +fn asn1_read_aes256cbc( + reader: &mut yasna::BERReaderSeq, +) -> Result, yasna::ASN1Error> { + let iv = reader.next().read_bytes()?; + let mut i = [0; 16]; + i.clone_from_slice(&iv); + Ok(Ok(Encryption::Aes256Cbc(i))) +} diff --git a/russh-keys/src/key.rs b/russh-keys/src/key.rs index 98ebda9f..f187d163 100644 --- a/russh-keys/src/key.rs +++ b/russh-keys/src/key.rs @@ -12,17 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. // +use std::borrow::Cow; +use std::convert::{TryFrom, TryInto}; + +pub use backend::{RsaPrivate, RsaPublic}; use ed25519_dalek::{Signer, Verifier}; -#[cfg(feature = "openssl")] -use openssl::pkey::{Private, Public}; use rand_core::OsRng; use russh_cryptovec::CryptoVec; use serde::{Deserialize, Serialize}; -use std::convert::TryFrom; use crate::encoding::{Encoding, Reader}; pub use crate::signature::*; -use crate::Error; +use crate::{backend, ec, protocol, Error}; #[derive(Debug, PartialEq, Eq, Copy, Clone)] /// Name of a public key algorithm. @@ -34,6 +35,12 @@ impl AsRef for Name { } } +/// The name of the ecdsa-sha2-nistp256 algorithm for SSH. +pub const ECDSA_SHA2_NISTP256: Name = Name("ecdsa-sha2-nistp256"); +/// The name of the ecdsa-sha2-nistp384 algorithm for SSH. +pub const ECDSA_SHA2_NISTP384: Name = Name("ecdsa-sha2-nistp384"); +/// The name of the ecdsa-sha2-nistp521 algorithm for SSH. +pub const ECDSA_SHA2_NISTP521: Name = Name("ecdsa-sha2-nistp521"); /// The name of the Ed25519 algorithm for SSH. pub const ED25519: Name = Name("ssh-ed25519"); /// The name of the ssh-sha2-512 algorithm for SSH. @@ -45,10 +52,21 @@ pub const NONE: Name = Name("none"); pub const SSH_RSA: Name = Name("ssh-rsa"); +pub static ALL_KEY_TYPES: &[&Name] = &[ + &NONE, + &SSH_RSA, + &RSA_SHA2_256, + &RSA_SHA2_512, + &ECDSA_SHA2_NISTP256, + &ECDSA_SHA2_NISTP384, + &ECDSA_SHA2_NISTP521, +]; + impl Name { /// Base name of the private key file for a key name. pub fn identity_file(&self) -> &'static str { match *self { + ECDSA_SHA2_NISTP256 | ECDSA_SHA2_NISTP384 | ECDSA_SHA2_NISTP521 => "id_ecdsa", ED25519 => "id_ed25519", RSA_SHA2_512 => "id_rsa", RSA_SHA2_256 => "id_rsa", @@ -57,6 +75,17 @@ impl Name { } } +impl TryFrom<&str> for Name { + type Error = (); + fn try_from(s: &str) -> Result { + ALL_KEY_TYPES + .iter() + .find(|x| x.0 == s) + .map(|x| **x) + .ok_or(()) + } +} + #[doc(hidden)] pub trait Verify { fn verify_client_auth(&self, buffer: &[u8], sig: &[u8]) -> bool; @@ -84,23 +113,12 @@ impl SignatureHash { } } - #[cfg(feature = "openssl")] - fn message_digest(&self) -> openssl::hash::MessageDigest { - use openssl::hash::MessageDigest; - match *self { - SignatureHash::SHA2_256 => MessageDigest::sha256(), - SignatureHash::SHA2_512 => MessageDigest::sha512(), - SignatureHash::SHA1 => MessageDigest::sha1(), - } - } - pub fn from_rsa_hostkey_algo(algo: &[u8]) -> Option { - if algo == b"rsa-sha2-256" { - Some(Self::SHA2_256) - } else if algo == b"rsa-sha2-512" { - Some(Self::SHA2_512) - } else { - Some(Self::SHA1) + match algo { + b"rsa-sha2-256" => Some(Self::SHA2_256), + b"rsa-sha2-512" => Some(Self::SHA2_512), + b"ssh-rsa" => Some(Self::SHA1), + _ => None, } } } @@ -111,108 +129,61 @@ pub enum PublicKey { #[doc(hidden)] Ed25519(ed25519_dalek::VerifyingKey), #[doc(hidden)] - #[cfg(feature = "openssl")] RSA { - key: OpenSSLPKey, + key: backend::RsaPublic, hash: SignatureHash, }, + #[doc(hidden)] + EC { key: ec::PublicKey }, } impl PartialEq for PublicKey { fn eq(&self, other: &Self) -> bool { match (self, other) { - #[cfg(feature = "openssl")] (Self::RSA { key: a, .. }, Self::RSA { key: b, .. }) => a == b, (Self::Ed25519(a), Self::Ed25519(b)) => a == b, - #[cfg(feature = "openssl")] + (Self::EC { key: a }, Self::EC { key: b }) => a == b, _ => false, } } } -/// A public key from OpenSSL. -#[cfg(feature = "openssl")] -#[derive(Clone)] -pub struct OpenSSLPKey(pub openssl::pkey::PKey); - -#[cfg(feature = "openssl")] -use std::cmp::{Eq, PartialEq}; - -#[cfg(feature = "openssl")] -impl PartialEq for OpenSSLPKey { - fn eq(&self, b: &OpenSSLPKey) -> bool { - self.0.public_eq(&b.0) - } -} -#[cfg(feature = "openssl")] -impl Eq for OpenSSLPKey {} -#[cfg(feature = "openssl")] -impl std::fmt::Debug for OpenSSLPKey { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "OpenSSLPKey {{ (hidden) }}") - } -} - impl PublicKey { /// Parse a public key in SSH format. pub fn parse(algo: &[u8], pubkey: &[u8]) -> Result { - match algo { - b"ssh-ed25519" => { - let mut p = pubkey.reader(0); - let key_algo = p.read_string()?; - let key_bytes = p.read_string()?; - if key_algo != b"ssh-ed25519" { - return Err(Error::CouldNotReadKey); - } - let Ok(key_bytes) = <&[u8; ed25519_dalek::PUBLIC_KEY_LENGTH]>::try_from(key_bytes) else { - return Err(Error::CouldNotReadKey); - }; - ed25519_dalek::VerifyingKey::from_bytes(key_bytes) - .map(PublicKey::Ed25519) - .map_err(Error::from) + use ssh_encoding::Decode; + let key_data = &ssh_key::public::KeyData::decode(&mut pubkey.reader(0))?; + let key_algo = key_data.algorithm(); + let key_algo = key_algo.as_str().as_bytes(); + if key_algo == b"ssh-rsa" { + if algo != SSH_RSA.as_ref().as_bytes() + && algo != RSA_SHA2_256.as_ref().as_bytes() + && algo != RSA_SHA2_512.as_ref().as_bytes() + { + return Err(Error::KeyIsCorrupt); } - b"ssh-rsa" | b"rsa-sha2-256" | b"rsa-sha2-512" if cfg!(feature = "openssl") => { - #[cfg(feature = "openssl")] - { - use log::debug; - let mut p = pubkey.reader(0); - let key_algo = p.read_string()?; - debug!("{:?}", std::str::from_utf8(key_algo)); - if key_algo != b"ssh-rsa" - && key_algo != b"rsa-sha2-256" - && key_algo != b"rsa-sha2-512" - { - return Err(Error::CouldNotReadKey); - } - let key_e = p.read_string()?; - let key_n = p.read_string()?; - use openssl::bn::BigNum; - use openssl::pkey::PKey; - use openssl::rsa::Rsa; - Ok(PublicKey::RSA { - key: OpenSSLPKey(PKey::from_rsa(Rsa::from_public_components( - BigNum::from_slice(key_n)?, - BigNum::from_slice(key_e)?, - )?)?), - hash: SignatureHash::from_rsa_hostkey_algo(algo) - .unwrap_or(SignatureHash::SHA1), - }) - } - #[cfg(not(feature = "openssl"))] - { - unreachable!() - } - } - _ => Err(Error::CouldNotReadKey), + } else if key_algo != algo { + return Err(Error::KeyIsCorrupt); } + Self::try_from(key_data) + } + + pub fn new_rsa_with_hash( + pk: &protocol::RsaPublicKey<'_>, + hash: SignatureHash, + ) -> Result { + Ok(PublicKey::RSA { + key: RsaPublic::try_from(pk)?, + hash, + }) } /// Algorithm name for that key. pub fn name(&self) -> &'static str { match *self { PublicKey::Ed25519(_) => ED25519.0, - #[cfg(feature = "openssl")] PublicKey::RSA { ref hash, .. } => hash.name().0, + PublicKey::EC { ref key } => key.algorithm(), } } @@ -226,17 +197,8 @@ impl PublicKey { let sig = ed25519_dalek::Signature::from_bytes(&sig); public.verify(buffer, &sig).is_ok() } - - #[cfg(feature = "openssl")] - PublicKey::RSA { ref key, ref hash } => { - use openssl::sign::*; - let verify = || { - let mut verifier = Verifier::new(hash.message_digest(), &key.0)?; - verifier.update(buffer)?; - verifier.verify(sig) - }; - verify().unwrap_or(false) - } + PublicKey::RSA { ref key, ref hash } => key.verify_detached(hash, buffer, sig), + PublicKey::EC { ref key, .. } => ec_verify(key, buffer, sig).is_ok(), } } @@ -250,21 +212,11 @@ impl PublicKey { data_encoding::BASE64_NOPAD.encode(&hasher.finalize()) } - #[cfg(feature = "openssl")] - pub fn set_algorithm(&mut self, algorithm: &[u8]) { + pub fn set_algorithm(&mut self, algorithm: SignatureHash) { if let PublicKey::RSA { ref mut hash, .. } = self { - if algorithm == b"rsa-sha2-512" { - *hash = SignatureHash::SHA2_512 - } else if algorithm == b"rsa-sha2-256" { - *hash = SignatureHash::SHA2_256 - } else if algorithm == b"ssh-rsa" { - *hash = SignatureHash::SHA1 - } + *hash = algorithm; } } - - #[cfg(not(feature = "openssl"))] - pub fn set_algorithm(&mut self, _: &[u8]) {} } impl Verify for PublicKey { @@ -280,11 +232,13 @@ impl Verify for PublicKey { #[allow(clippy::large_enum_variant)] pub enum KeyPair { Ed25519(ed25519_dalek::SigningKey), - #[cfg(feature = "openssl")] RSA { - key: openssl::rsa::Rsa, + key: backend::RsaPrivate, hash: SignatureHash, }, + EC { + key: ec::PrivateKey, + }, } impl Clone for KeyPair { @@ -294,11 +248,11 @@ impl Clone for KeyPair { Self::Ed25519(kp) => { Self::Ed25519(ed25519_dalek::SigningKey::from_bytes(&kp.to_bytes())) } - #[cfg(feature = "openssl")] Self::RSA { key, hash } => Self::RSA { key: key.clone(), hash: *hash, }, + Self::EC { key } => Self::EC { key: key.clone() }, } } } @@ -311,8 +265,8 @@ impl std::fmt::Debug for KeyPair { "Ed25519 {{ public: {:?}, secret: (hidden) }}", key.verifying_key().as_bytes() ), - #[cfg(feature = "openssl")] KeyPair::RSA { .. } => write!(f, "RSA {{ (hidden) }}"), + KeyPair::EC { .. } => write!(f, "EC {{ (hidden) }}"), } } } @@ -324,20 +278,28 @@ impl<'b> crate::encoding::Bytes for &'b KeyPair { } impl KeyPair { + pub fn new_rsa_with_hash( + sk: &protocol::RsaPrivateKey<'_>, + extra: Option<&RsaCrtExtra<'_>>, + hash: SignatureHash, + ) -> Result { + Ok(KeyPair::RSA { + key: RsaPrivate::new(sk, extra)?, + hash, + }) + } + /// Copy the public key of this algorithm. pub fn clone_public_key(&self) -> Result { Ok(match self { KeyPair::Ed25519(ref key) => PublicKey::Ed25519(key.verifying_key()), - #[cfg(feature = "openssl")] - KeyPair::RSA { ref key, ref hash } => { - use openssl::pkey::PKey; - use openssl::rsa::Rsa; - let key = Rsa::from_public_components(key.n().to_owned()?, key.e().to_owned()?)?; - PublicKey::RSA { - key: OpenSSLPKey(PKey::from_rsa(key)?), - hash: *hash, - } - } + KeyPair::RSA { ref key, ref hash } => PublicKey::RSA { + key: key.try_into()?, + hash: *hash, + }, + KeyPair::EC { ref key } => PublicKey::EC { + key: key.to_public_key(), + }, }) } @@ -345,12 +307,12 @@ impl KeyPair { pub fn name(&self) -> &'static str { match *self { KeyPair::Ed25519(_) => ED25519.0, - #[cfg(feature = "openssl")] KeyPair::RSA { ref hash, .. } => hash.name().0, + KeyPair::EC { ref key } => key.algorithm(), } } - /// Generate a key pair. + /// Generate a ED25519 key pair. pub fn generate_ed25519() -> Option { let keypair = ed25519_dalek::SigningKey::generate(&mut OsRng {}); assert_eq!( @@ -360,9 +322,9 @@ impl KeyPair { Some(KeyPair::Ed25519(keypair)) } - #[cfg(feature = "openssl")] + /// Generate a RSA key pair. pub fn generate_rsa(bits: usize, hash: SignatureHash) -> Option { - let key = openssl::rsa::Rsa::generate(bits as u32).ok()?; + let key = RsaPrivate::generate(bits).ok()?; Some(KeyPair::RSA { key, hash }) } @@ -373,11 +335,14 @@ impl KeyPair { KeyPair::Ed25519(ref secret) => Ok(Signature::Ed25519(SignatureBytes( secret.sign(to_sign).to_bytes(), ))), - #[cfg(feature = "openssl")] KeyPair::RSA { ref key, ref hash } => Ok(Signature::RSA { - bytes: rsa_signature(hash, key, to_sign)?, + bytes: key.sign(hash, to_sign)?, hash: *hash, }), + KeyPair::EC { ref key } => Ok(Signature::ECDSA { + algorithm: key.algorithm(), + signature: ec_signature(key, to_sign)?, + }), } } @@ -398,15 +363,21 @@ impl KeyPair { buffer.extend_ssh_string(ED25519.0.as_bytes()); buffer.extend_ssh_string(signature.to_bytes().as_slice()); } - #[cfg(feature = "openssl")] KeyPair::RSA { ref key, ref hash } => { // https://tools.ietf.org/html/draft-rsa-dsa-sha2-256-02#section-2.2 - let signature = rsa_signature(hash, key, to_sign.as_ref())?; + let signature = key.sign(hash, to_sign.as_ref())?; let name = hash.name(); buffer.push_u32_be((name.0.len() + signature.len() + 8) as u32); buffer.extend_ssh_string(name.0.as_bytes()); buffer.extend_ssh_string(&signature); } + KeyPair::EC { ref key } => { + let algorithm = key.algorithm().as_bytes(); + let signature = ec_signature(key, to_sign.as_ref())?; + buffer.push_u32_be((algorithm.len() + signature.len() + 8) as u32); + buffer.extend_ssh_string(algorithm); + buffer.extend_ssh_string(&signature); + } } Ok(()) } @@ -423,89 +394,88 @@ impl KeyPair { buffer.extend_ssh_string(ED25519.0.as_bytes()); buffer.extend_ssh_string(signature.to_bytes().as_slice()); } - #[cfg(feature = "openssl")] KeyPair::RSA { ref key, ref hash } => { // https://tools.ietf.org/html/draft-rsa-dsa-sha2-256-02#section-2.2 - let signature = rsa_signature(hash, key, buffer)?; + let signature = key.sign(hash, buffer)?; let name = hash.name(); buffer.push_u32_be((name.0.len() + signature.len() + 8) as u32); buffer.extend_ssh_string(name.0.as_bytes()); buffer.extend_ssh_string(&signature); } + KeyPair::EC { ref key } => { + let signature = ec_signature(key, buffer)?; + let algorithm = key.algorithm().as_bytes(); + buffer.push_u32_be((algorithm.len() + signature.len() + 8) as u32); + buffer.extend_ssh_string(algorithm); + buffer.extend_ssh_string(&signature); + } } Ok(()) } /// Create a copy of an RSA key with a specified hash algorithm. - #[cfg(feature = "openssl")] pub fn with_signature_hash(&self, hash: SignatureHash) -> Option { match self { KeyPair::Ed25519(_) => None, - #[cfg(feature = "openssl")] KeyPair::RSA { key, .. } => Some(KeyPair::RSA { key: key.clone(), hash, }), + KeyPair::EC { .. } => None, } } } -#[cfg(feature = "openssl")] -fn rsa_signature( - hash: &SignatureHash, - key: &openssl::rsa::Rsa, - b: &[u8], -) -> Result, Error> { - use openssl::pkey::*; - use openssl::rsa::*; - use openssl::sign::Signer; - let pkey = PKey::from_rsa(Rsa::from_private_components( - key.n().to_owned()?, - key.e().to_owned()?, - key.d().to_owned()?, - key.p().ok_or(Error::KeyIsCorrupt)?.to_owned()?, - key.q().ok_or(Error::KeyIsCorrupt)?.to_owned()?, - key.dmp1().ok_or(Error::KeyIsCorrupt)?.to_owned()?, - key.dmq1().ok_or(Error::KeyIsCorrupt)?.to_owned()?, - key.iqmp().ok_or(Error::KeyIsCorrupt)?.to_owned()?, - )?)?; - let mut signer = Signer::new(hash.message_digest(), &pkey)?; - signer.update(b)?; - Ok(signer.sign_to_vec()?) +/// Extra CRT parameters for RSA private key. +pub struct RsaCrtExtra<'a> { + /// `d mod (p-1)`. + pub dp: Cow<'a, [u8]>, + /// `d mod (q-1)`. + pub dq: Cow<'a, [u8]>, } -/// Parse a public key from a byte slice. -pub fn parse_public_key( - p: &[u8], - #[cfg(feature = "openssl")] prefer_hash: Option, -) -> Result { - let mut pos = p.reader(0); - let t = pos.read_string()?; - if t == b"ssh-ed25519" { - if let Ok(pubkey) = pos.read_string() { - let Ok(pubkey) = <&[u8; ed25519_dalek::PUBLIC_KEY_LENGTH]>::try_from(pubkey) else { - return Err(Error::CouldNotReadKey); - }; - let p = ed25519_dalek::VerifyingKey::from_bytes(pubkey).map_err(Error::from)?; - return Ok(PublicKey::Ed25519(p)); - } +impl Drop for RsaCrtExtra<'_> { + fn drop(&mut self) { + zeroize_cow(&mut self.dp); + zeroize_cow(&mut self.dq); } - if t == b"ssh-rsa" { - #[cfg(feature = "openssl")] - { - let e = pos.read_string()?; - let n = pos.read_string()?; - use openssl::bn::*; - use openssl::pkey::*; - use openssl::rsa::*; - return Ok(PublicKey::RSA { - key: OpenSSLPKey(PKey::from_rsa(Rsa::from_public_components( - BigNum::from_slice(n)?, - BigNum::from_slice(e)?, - )?)?), - hash: prefer_hash.unwrap_or(SignatureHash::SHA2_256), - }); - } +} + +fn ec_signature(key: &ec::PrivateKey, b: &[u8]) -> Result, Error> { + let (r, s) = key.try_sign(b)?; + let mut buf = Vec::new(); + buf.extend_ssh_mpint(&r); + buf.extend_ssh_mpint(&s); + Ok(buf) +} + +fn ec_verify(key: &ec::PublicKey, b: &[u8], sig: &[u8]) -> Result<(), Error> { + let mut reader = sig.reader(0); + key.verify(b, reader.read_mpint()?, reader.read_mpint()?) +} + +/// Parse a public key from a byte slice. +pub fn parse_public_key(p: &[u8], prefer_hash: Option) -> Result { + use ssh_encoding::Decode; + let mut key = PublicKey::try_from(&ssh_key::public::KeyData::decode(&mut p.reader(0))?)?; + key.set_algorithm(prefer_hash.unwrap_or(SignatureHash::SHA2_256)); + Ok(key) +} + +/// Obtain a cryptographic-safe random number generator. +pub fn safe_rng() -> impl rand::CryptoRng + rand::RngCore { + rand::thread_rng() +} + +/// Zeroize `Cow` if value is owned. +pub(crate) fn zeroize_cow(v: &mut Cow) +where + T: ToOwned + ?Sized, + ::Owned: zeroize::Zeroize, +{ + use zeroize::Zeroize; + match v { + Cow::Owned(v) => v.zeroize(), + Cow::Borrowed(_) => (), } - Err(Error::CouldNotReadKey) } diff --git a/russh-keys/src/lib.rs b/russh-keys/src/lib.rs index 09287aa0..25e65ee9 100644 --- a/russh-keys/src/lib.rs +++ b/russh-keys/src/lib.rs @@ -10,11 +10,11 @@ //! opening key files, deciphering encrypted keys, and dealing with //! agents. //! -//! The following example (which uses the `openssl` feature) shows how -//! to do all these in a single example: start and SSH agent server, -//! connect to it with a client, decipher an encrypted private key -//! (the password is `b"blabla"`), send it to the agent, and ask the -//! agent to sign a piece of data (`b"Please sign this", below). +//! The following example shows how to do all these in a single example: +//! start and SSH agent server, connect to it with a client, decipher +//! an encrypted private key (the password is `b"blabla"`), send it to +//! the agent, and ask the agent to sign a piece of data +//! (`b"Please sign this"`, below). //! //!``` //! use russh_keys::*; @@ -30,7 +30,7 @@ //! //! const PKCS8_ENCRYPTED: &'static str = "-----BEGIN ENCRYPTED PRIVATE KEY-----\nMIIFLTBXBgkqhkiG9w0BBQ0wSjApBgkqhkiG9w0BBQwwHAQITo1O0b8YrS0CAggA\nMAwGCCqGSIb3DQIJBQAwHQYJYIZIAWUDBAEqBBBtLH4T1KOfo1GGr7salhR8BIIE\n0KN9ednYwcTGSX3hg7fROhTw7JAJ1D4IdT1fsoGeNu2BFuIgF3cthGHe6S5zceI2\nMpkfwvHbsOlDFWMUIAb/VY8/iYxhNmd5J6NStMYRC9NC0fVzOmrJqE1wITqxtORx\nIkzqkgFUbaaiFFQPepsh5CvQfAgGEWV329SsTOKIgyTj97RxfZIKA+TR5J5g2dJY\nj346SvHhSxJ4Jc0asccgMb0HGh9UUDzDSql0OIdbnZW5KzYJPOx+aDqnpbz7UzY/\nP8N0w/pEiGmkdkNyvGsdttcjFpOWlLnLDhtLx8dDwi/sbEYHtpMzsYC9jPn3hnds\nTcotqjoSZ31O6rJD4z18FOQb4iZs3MohwEdDd9XKblTfYKM62aQJWH6cVQcg+1C7\njX9l2wmyK26Tkkl5Qg/qSfzrCveke5muZgZkFwL0GCcgPJ8RixSB4GOdSMa/hAMU\nkvFAtoV2GluIgmSe1pG5cNMhurxM1dPPf4WnD+9hkFFSsMkTAuxDZIdDk3FA8zof\nYhv0ZTfvT6V+vgH3Hv7Tqcxomy5Qr3tj5vvAqqDU6k7fC4FvkxDh2mG5ovWvc4Nb\nXv8sed0LGpYitIOMldu6650LoZAqJVv5N4cAA2Edqldf7S2Iz1QnA/usXkQd4tLa\nZ80+sDNv9eCVkfaJ6kOVLk/ghLdXWJYRLenfQZtVUXrPkaPpNXgD0dlaTN8KuvML\nUw/UGa+4ybnPsdVflI0YkJKbxouhp4iB4S5ACAwqHVmsH5GRnujf10qLoS7RjDAl\no/wSHxdT9BECp7TT8ID65u2mlJvH13iJbktPczGXt07nBiBse6OxsClfBtHkRLzE\nQF6UMEXsJnIIMRfrZQnduC8FUOkfPOSXc8r9SeZ3GhfbV/DmWZvFPCpjzKYPsM5+\nN8Bw/iZ7NIH4xzNOgwdp5BzjH9hRtCt4sUKVVlWfEDtTnkHNOusQGKu7HkBF87YZ\nRN/Nd3gvHob668JOcGchcOzcsqsgzhGMD8+G9T9oZkFCYtwUXQU2XjMN0R4VtQgZ\nrAxWyQau9xXMGyDC67gQ5xSn+oqMK0HmoW8jh2LG/cUowHFAkUxdzGadnjGhMOI2\nzwNJPIjF93eDF/+zW5E1l0iGdiYyHkJbWSvcCuvTwma9FIDB45vOh5mSR+YjjSM5\nnq3THSWNi7Cxqz12Q1+i9pz92T2myYKBBtu1WDh+2KOn5DUkfEadY5SsIu/Rb7ub\n5FBihk2RN3y/iZk+36I69HgGg1OElYjps3D+A9AjVby10zxxLAz8U28YqJZm4wA/\nT0HLxBiVw+rsHmLP79KvsT2+b4Diqih+VTXouPWC/W+lELYKSlqnJCat77IxgM9e\nYIhzD47OgWl33GJ/R10+RDoDvY4koYE+V5NLglEhbwjloo9Ryv5ywBJNS7mfXMsK\n/uf+l2AscZTZ1mhtL38efTQCIRjyFHc3V31DI0UdETADi+/Omz+bXu0D5VvX+7c6\nb1iVZKpJw8KUjzeUV8yOZhvGu3LrQbhkTPVYL555iP1KN0Eya88ra+FUKMwLgjYr\nJkUx4iad4dTsGPodwEP/Y9oX/Qk3ZQr+REZ8lg6IBoKKqqrQeBJ9gkm1jfKE6Xkc\nCog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux\n-----END ENCRYPTED PRIVATE KEY-----\n"; //! -//! #[cfg(all(unix, feature = "openssl"))] +//! #[cfg(unix)] //! fn main() { //! env_logger::try_init().unwrap_or(()); //! let dir = tempdir::TempDir::new("russh").unwrap(); @@ -58,7 +58,7 @@ //! }).unwrap() //! } //! -//! #[cfg(any(not(unix), not(feature = "openssl")))] +//! #[cfg(not(unix))] //! fn main() {} //! //! ``` @@ -66,23 +66,35 @@ use std::borrow::Cow; use std::fs::{File, OpenOptions}; use std::io::{BufRead, BufReader, Read, Seek, SeekFrom, Write}; -use std::path::Path; +use std::path::{Path, PathBuf}; use aes::cipher::block_padding::UnpadError; use aes::cipher::inout::PadError; use byteorder::{BigEndian, WriteBytesExt}; use data_encoding::BASE64_MIME; +use hmac::{Hmac, Mac}; use log::debug; +use sha1::Sha1; +use ssh_key::Certificate; use thiserror::Error; +pub mod ec; pub mod encoding; pub mod key; +pub mod protocol; pub mod signature; mod format; pub use format::*; -/// A module to write SSH agent. +#[cfg(feature = "openssl")] +#[path = "backend_openssl.rs"] +mod backend; +#[cfg(not(feature = "openssl"))] +#[path = "backend_rust.rs"] +mod backend; + +/// OpenSSH agent protocol implementation pub mod agent; #[derive(Debug, Error)] @@ -99,6 +111,9 @@ pub enum Error { /// The type of the key is unsupported #[error("Invalid Ed25519 key data")] Ed25519KeyError(#[from] ed25519_dalek::SignatureError), + /// The type of the key is unsupported + #[error("Invalid ECDSA key data")] + EcdsaKeyError(#[from] p256::elliptic_curve::Error), /// The key is encrypted (should supply a password?) #[error("The key is encrypted")] KeyIsEncrypted, @@ -112,14 +127,18 @@ pub enum Error { #[error("The server key changed at line {}", line)] KeyChanged { line: usize }, /// The key uses an unsupported algorithm - #[error("Unknown key algorithm")] - UnknownAlgorithm(yasna::models::ObjectIdentifier), + #[error("Unknown key algorithm: {0}")] + UnknownAlgorithm(::pkcs8::ObjectIdentifier), /// Index out of bounds #[error("Index out of bounds")] IndexOutOfBounds, /// Unknown signature type #[error("Unknown signature type: {}", sig_type)] UnknownSignatureType { sig_type: String }, + #[error("Invalid signature")] + InvalidSignature, + #[error("Invalid parameters")] + InvalidParameters, /// Agent protocol error #[error("Agent protocol error")] AgentProtocolError, @@ -132,6 +151,10 @@ pub enum Error { #[error(transparent)] Openssl(#[from] openssl::error::ErrorStack), + #[cfg(not(feature = "openssl"))] + #[error("Rsa: {0}")] + Rsa(#[from] rsa::Error), + #[error(transparent)] Pad(#[from] PadError), @@ -140,8 +163,22 @@ pub enum Error { #[error("Base64 decoding error: {0}")] Decode(#[from] data_encoding::DecodeError), - #[error("ASN1 decoding error: {0}")] - ASN1(yasna::ASN1Error), + #[error("Der: {0}")] + Der(#[from] der::Error), + #[error("Spki: {0}")] + Spki(#[from] spki::Error), + #[error("Pkcs1: {0}")] + Pkcs1(#[from] pkcs1::Error), + #[error("Pkcs8: {0}")] + Pkcs8(#[from] ::pkcs8::Error), + #[error("Sec1: {0}")] + Sec1(#[from] sec1::Error), + + #[error("SshKey: {0}")] + SshKey(#[from] ssh_key::Error), + #[error("SshEncoding: {0}")] + SshEncoding(#[from] ssh_encoding::Error), + #[error("Environment variable `{0}` not found")] EnvVar(&'static str), #[error( @@ -149,18 +186,32 @@ pub enum Error { points to a nonexistent file or directory." )] BadAuthSock, + + #[error("ASN1 decoding error: {0}")] + #[cfg(feature = "legacy-ed25519-pkcs8-parser")] + LegacyASN1(::yasna::ASN1Error), + + #[cfg(windows)] + #[error("Pageant: {0}")] + Pageant(#[from] pageant::Error), } +#[cfg(feature = "legacy-ed25519-pkcs8-parser")] impl From for Error { fn from(e: yasna::ASN1Error) -> Error { - Error::ASN1(e) + Error::LegacyASN1(e) } } -const KEYTYPE_ED25519: &[u8] = b"ssh-ed25519"; -const KEYTYPE_RSA: &[u8] = b"ssh-rsa"; +const KEYTYPE_ECDSA_SHA2_NISTP256: &[u8] = ECDSA_SHA2_NISTP256.as_bytes(); +const KEYTYPE_ECDSA_SHA2_NISTP384: &[u8] = ECDSA_SHA2_NISTP384.as_bytes(); +const KEYTYPE_ECDSA_SHA2_NISTP521: &[u8] = ECDSA_SHA2_NISTP521.as_bytes(); + +const ECDSA_SHA2_NISTP256: &str = "ecdsa-sha2-nistp256"; +const ECDSA_SHA2_NISTP384: &str = "ecdsa-sha2-nistp384"; +const ECDSA_SHA2_NISTP521: &str = "ecdsa-sha2-nistp521"; -/// Load a public key from a file. Ed25519 and RSA keys are supported. +/// Load a public key from a file. Ed25519, EC-DSA and RSA keys are supported. /// /// ``` /// russh_keys::load_public_key("../files/id_ed25519.pub").unwrap(); @@ -187,11 +238,7 @@ pub fn load_public_key>(path: P) -> Result /// ``` pub fn parse_public_key_base64(key: &str) -> Result { let base = BASE64_MIME.decode(key.as_bytes())?; - key::parse_public_key( - &base, - #[cfg(feature = "openssl")] - None, - ) + key::parse_public_key(&base, None) } pub trait PublicKeyBase64 { @@ -219,17 +266,16 @@ impl PublicKeyBase64 for key::PublicKey { .unwrap(); s.extend_from_slice(publickey.as_bytes()); } - #[cfg(feature = "openssl")] key::PublicKey::RSA { ref key, .. } => { use encoding::Encoding; let name = b"ssh-rsa"; #[allow(clippy::unwrap_used)] // Vec<>.write_all can't fail s.write_u32::(name.len() as u32).unwrap(); s.extend_from_slice(name); - #[allow(clippy::unwrap_used)] // TODO check - s.extend_ssh_mpint(&key.0.rsa().unwrap().e().to_vec()); - #[allow(clippy::unwrap_used)] // TODO check - s.extend_ssh_mpint(&key.0.rsa().unwrap().n().to_vec()); + s.extend_ssh(&protocol::RsaPublicKey::from(key)); + } + key::PublicKey::EC { ref key } => { + write_ec_public_key(&mut s, key); } } s @@ -250,17 +296,29 @@ impl PublicKeyBase64 for key::KeyPair { s.write_u32::(public.len() as u32).unwrap(); s.extend_from_slice(public.as_slice()); } - #[cfg(feature = "openssl")] key::KeyPair::RSA { ref key, .. } => { use encoding::Encoding; - s.extend_ssh_mpint(&key.e().to_vec()); - s.extend_ssh_mpint(&key.n().to_vec()); + s.extend_ssh(&protocol::RsaPublicKey::from(key)); + } + key::KeyPair::EC { ref key } => { + write_ec_public_key(&mut s, &key.to_public_key()); } } s } } +fn write_ec_public_key(buf: &mut Vec, key: &ec::PublicKey) { + let algorithm = key.algorithm().as_bytes(); + let ident = key.ident().as_bytes(); + let q = key.to_sec1_bytes(); + + use encoding::Encoding; + buf.extend_ssh_string(algorithm); + buf.extend_ssh_string(ident); + buf.extend_ssh_string(&q); +} + /// Write a public key onto the provided `Write`, encoded in base-64. pub fn write_public_key_base64( mut w: W, @@ -282,15 +340,29 @@ pub fn load_secret_key>( decode_secret_key(&secret, password) } +/// Load a openssh certificate +pub fn load_openssh_certificate>(cert_: P) -> Result { + let mut cert_file = std::fs::File::open(cert_)?; + let mut cert = String::new(); + cert_file.read_to_string(&mut cert)?; + + Certificate::from_openssh(&cert) +} + fn is_base64_char(c: char) -> bool { - ('a'..='z').contains(&c) - || ('A'..='Z').contains(&c) - || ('0'..='9').contains(&c) + c.is_ascii_lowercase() + || c.is_ascii_uppercase() + || c.is_ascii_digit() || c == '/' || c == '+' || c == '=' } +/// Record a host's public key into the user's known_hosts file. +pub fn learn_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result<(), Error> { + learn_known_hosts_path(host, port, pubkey, known_hosts_path()?) +} + /// Record a host's public key into a nonstandard location. pub fn learn_known_hosts_path>( host: &str, @@ -331,17 +403,21 @@ pub fn learn_known_hosts_path>( Ok(()) } -/// Check that a server key matches the one recorded in file `path`. -pub fn check_known_hosts_path>( +/// Get the server key that matches the one recorded in the user's known_hosts file. +pub fn known_host_keys(host: &str, port: u16) -> Result, Error> { + known_host_keys_path(host, port, known_hosts_path()?) +} + +/// Get the server key that matches the one recorded in `path`. +pub fn known_host_keys_path>( host: &str, port: u16, - pubkey: &key::PublicKey, path: P, -) -> Result { +) -> Result, Error> { let mut f = if let Ok(f) = File::open(path) { BufReader::new(f) } else { - return Ok(false); + return Ok(vec![]); }; let mut buffer = String::new(); @@ -352,6 +428,7 @@ pub fn check_known_hosts_path>( }; debug!("host_port = {:?}", host_port); let mut line = 1; + let mut matches = vec![]; while f.read_line(&mut buffer)? > 0 { { if buffer.as_bytes().first() == Some(&b'#') { @@ -365,65 +442,82 @@ pub fn check_known_hosts_path>( let key = s.next(); if let (Some(h), Some(k)) = (hosts, key) { debug!("{:?} {:?}", h, k); - let host_matches = h.split(',').any(|x| x == host_port); - if host_matches { - if &parse_public_key_base64(k)? == pubkey { - return Ok(true); - } else { - return Err(Error::KeyChanged { line }); - } + if match_hostname(&host_port, h) { + matches.push((line, parse_public_key_base64(k)?)); } } } buffer.clear(); line += 1; } - Ok(false) + Ok(matches) } -/// Record a host's public key into the user's known_hosts file. -#[cfg(target_os = "windows")] -pub fn learn_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result<(), Error> { - if let Some(mut known_host_file) = dirs::home_dir() { - known_host_file.push("ssh"); - known_host_file.push("known_hosts"); - learn_known_hosts_path(host, port, pubkey, &known_host_file) - } else { - Err(Error::NoHomeDir) +fn match_hostname(host: &str, pattern: &str) -> bool { + for entry in pattern.split(',') { + if entry.starts_with("|1|") { + let mut parts = entry.split('|').skip(2); + let Some(Ok(salt)) = parts.next().map(|p| BASE64_MIME.decode(p.as_bytes())) else { + continue; + }; + let Some(Ok(hash)) = parts.next().map(|p| BASE64_MIME.decode(p.as_bytes())) else { + continue; + }; + if let Ok(hmac) = Hmac::::new_from_slice(&salt) { + if hmac.chain_update(host).verify_slice(&hash).is_ok() { + return true; + } + } + } else if host == entry { + return true; + } } + false } -/// Record a host's public key into the user's known_hosts file. -#[cfg(not(target_os = "windows"))] -pub fn learn_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result<(), Error> { - if let Some(mut known_host_file) = dirs::home_dir() { - known_host_file.push(".ssh"); - known_host_file.push("known_hosts"); - learn_known_hosts_path(host, port, pubkey, &known_host_file) - } else { - Err(Error::NoHomeDir) - } +/// Check whether the host is known, from its standard location. +pub fn check_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result { + check_known_hosts_path(host, port, pubkey, known_hosts_path()?) +} + +/// Check that a server key matches the one recorded in file `path`. +pub fn check_known_hosts_path>( + host: &str, + port: u16, + pubkey: &key::PublicKey, + path: P, +) -> Result { + let check = known_host_keys_path(host, port, path)? + .into_iter() + .map( + |(line, recorded)| match (pubkey.name() == recorded.name(), *pubkey == recorded) { + (true, true) => Ok(true), + (true, false) => Err(Error::KeyChanged { line }), + _ => Ok(false), + }, + ) + // If any Err was returned, we stop here + .collect::, Error>>()? + .into_iter() + // Now we check the results for a match + .any(|x| x); + + Ok(check) } -/// Check whether the host is known, from its standard location. #[cfg(target_os = "windows")] -pub fn check_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result { - if let Some(mut known_host_file) = dirs::home_dir() { - known_host_file.push("ssh"); - known_host_file.push("known_hosts"); - check_known_hosts_path(host, port, pubkey, &known_host_file) +fn known_hosts_path() -> Result { + if let Some(home_dir) = home::home_dir() { + Ok(home_dir.join("ssh").join("known_hosts")) } else { - Err(Error::NoHomeDir.into()) + Err(Error::NoHomeDir) } } -/// Check whether the host is known, from its standard location. #[cfg(not(target_os = "windows"))] -pub fn check_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result { - if let Some(mut known_host_file) = dirs::home_dir() { - known_host_file.push(".ssh"); - known_host_file.push("known_hosts"); - check_known_hosts_path(host, port, pubkey, &known_host_file) +fn known_hosts_path() -> Result { + if let Some(home_dir) = home::home_dir() { + Ok(home_dir.join(".ssh").join("known_hosts")) } else { Err(Error::NoHomeDir) } @@ -434,7 +528,7 @@ mod test { use std::fs::File; use std::io::Write; - #[cfg(feature = "openssl")] + #[cfg(unix)] use futures::Future; use super::*; @@ -458,7 +552,6 @@ dP3jryYgvsCIBAA5jMWSjrmnOTXhidqcOy4xYCrAttzSnZ/cUadfBenL+DQq6neffw7j8r sJWR7W+cGvJ/vLsw== -----END OPENSSH PRIVATE KEY-----"; - #[cfg(feature = "openssl")] const RSA_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY----- b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn NhAAAAAwEAAQAAAQEAuSvQ9m76zhRB4m0BUKPf17lwccj7KQ1Qtse63AOqP/VYItqEH8un @@ -499,15 +592,112 @@ QR+u0AypRPmzHnOPAAAAEXJvb3RAMTQwOTExNTQ5NDBkAQ== decode_secret_key(ED25519_AESCTR_KEY, Some("test")).unwrap(); } + // Key from RFC 8410 Section 10.3. This is a key using PrivateKeyInfo structure. + const RFC8410_ED25519_PRIVATE_ONLY_KEY: &str = "-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEINTuctv5E1hK1bbY8fdp+K06/nwoy/HU++CXqI9EdVhC +-----END PRIVATE KEY-----"; + + #[test] + fn test_decode_rfc8410_ed25519_private_only_key() { + env_logger::try_init().unwrap_or(()); + assert!(matches!( + decode_secret_key(RFC8410_ED25519_PRIVATE_ONLY_KEY, None), + Ok(key::KeyPair::Ed25519 { .. }) + )); + // We always encode public key, skip test_decode_encode_symmetry. + } + + // Key from RFC 8410 Section 10.3. This is a key using OneAsymmetricKey structure. + const RFC8410_ED25519_PRIVATE_PUBLIC_KEY: &str = "-----BEGIN PRIVATE KEY----- +MHICAQEwBQYDK2VwBCIEINTuctv5E1hK1bbY8fdp+K06/nwoy/HU++CXqI9EdVhC +oB8wHQYKKoZIhvcNAQkJFDEPDA1DdXJkbGUgQ2hhaXJzgSEAGb9ECWmEzf6FQbrB +Z9w7lshQhqowtrbLDFw4rXAxZuE= +-----END PRIVATE KEY-----"; + + #[test] + fn test_decode_rfc8410_ed25519_private_public_key() { + env_logger::try_init().unwrap_or(()); + assert!(matches!( + decode_secret_key(RFC8410_ED25519_PRIVATE_PUBLIC_KEY, None), + Ok(key::KeyPair::Ed25519 { .. }) + )); + // We can't encode attributes, skip test_decode_encode_symmetry. + } + #[test] - #[cfg(feature = "openssl")] fn test_decode_rsa_secret_key() { env_logger::try_init().unwrap_or(()); decode_secret_key(RSA_KEY, None).unwrap(); } #[test] - #[cfg(feature = "openssl")] + fn test_decode_openssh_p256_secret_key() { + // Generated using: ssh-keygen -t ecdsa -b 256 -m rfc4716 -f $file + let key = "-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAaAAAABNlY2RzYS +1zaGEyLW5pc3RwMjU2AAAACG5pc3RwMjU2AAAAQQQ/i+HCsmZZPy0JhtT64vW7EmeA1DeA +M/VnPq3vAhu+xooJ7IMMK3lUHlBDosyvA2enNbCWyvNQc25dVt4oh9RhAAAAqHG7WMFxu1 +jBAAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBD+L4cKyZlk/LQmG +1Pri9bsSZ4DUN4Az9Wc+re8CG77GignsgwwreVQeUEOizK8DZ6c1sJbK81Bzbl1W3iiH1G +EAAAAgLAmXR6IlN0SdiD6o8qr+vUr0mXLbajs/m0UlegElOmoAAAANcm9iZXJ0QGJic2Rl +dgECAw== +-----END OPENSSH PRIVATE KEY----- +"; + assert!(matches!( + decode_secret_key(key, None), + Ok(key::KeyPair::EC { + key: ec::PrivateKey::P256(_), + }) + )); + } + + #[test] + fn test_decode_openssh_p384_secret_key() { + // Generated using: ssh-keygen -t ecdsa -b 384 -m rfc4716 -f $file + let key = "-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAiAAAABNlY2RzYS +1zaGEyLW5pc3RwMzg0AAAACG5pc3RwMzg0AAAAYQTkLnKPk/1NZD9mQ8XoebD7ASv9/svh +5jO75HF7RYAqKK3fl5wsHe4VTJAOT3qH841yTcK79l0dwhHhHeg60byL7F9xOEzr2kqGeY +Uwrl7fVaL7hfHzt6z+sG8smSQ3tF8AAADYHjjBch44wXIAAAATZWNkc2Etc2hhMi1uaXN0 +cDM4NAAAAAhuaXN0cDM4NAAAAGEE5C5yj5P9TWQ/ZkPF6Hmw+wEr/f7L4eYzu+Rxe0WAKi +it35ecLB3uFUyQDk96h/ONck3Cu/ZdHcIR4R3oOtG8i+xfcThM69pKhnmFMK5e31Wi+4Xx +87es/rBvLJkkN7RfAAAAMFzt6053dxaQT0Ta/CGfZna0nibHzxa55zgBmje/Ho3QDNlBCH +Ylv0h4Wyzto8NfLQAAAA1yb2JlcnRAYmJzZGV2AQID +-----END OPENSSH PRIVATE KEY----- +"; + assert!(matches!( + decode_secret_key(key, None), + Ok(key::KeyPair::EC { + key: ec::PrivateKey::P384(_), + }) + )); + } + + #[test] + fn test_decode_openssh_p521_secret_key() { + // Generated using: ssh-keygen -t ecdsa -b 521 -m rfc4716 -f $file + let key = "-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAArAAAABNlY2RzYS +1zaGEyLW5pc3RwNTIxAAAACG5pc3RwNTIxAAAAhQQA7a9awmFeDjzYiuUOwMfXkKTevfQI +iGlduu8BkjBOWXpffJpKsdTyJI/xI05l34OvqfCCkPUcfFWHK+LVRGahMBgBcGB9ZZOEEq +iKNIT6C9WcJTGDqcBSzQ2yTSOxPXfUmVTr4D76vbYu5bjd9aBKx8HdfMvPeo0WD0ds/LjX +LdJoDXcAAAEQ9fxlIfX8ZSEAAAATZWNkc2Etc2hhMi1uaXN0cDUyMQAAAAhuaXN0cDUyMQ +AAAIUEAO2vWsJhXg482IrlDsDH15Ck3r30CIhpXbrvAZIwTll6X3yaSrHU8iSP8SNOZd+D +r6nwgpD1HHxVhyvi1URmoTAYAXBgfWWThBKoijSE+gvVnCUxg6nAUs0Nsk0jsT131JlU6+ +A++r22LuW43fWgSsfB3XzLz3qNFg9HbPy41y3SaA13AAAAQgH4DaftY0e/KsN695VJ06wy +Ve0k2ddxoEsSE15H4lgNHM2iuYKzIqZJOReHRCTff6QGgMYPDqDfFfL1Hc1Ntql0pwAAAA +1yb2JlcnRAYmJzZGV2AQIDBAU= +-----END OPENSSH PRIVATE KEY----- +"; + assert!(matches!( + decode_secret_key(key, None), + Ok(key::KeyPair::EC { + key: ec::PrivateKey::P521(_), + }) + )); + } + + #[test] fn test_fingerprint() { let key = parse_public_key_base64( "AAAAC3NzaC1lZDI1NTE5AAAAILagOJFgwaMNhBWQINinKOXmqS4Gh5NgxgriXwdOoINJ", @@ -526,7 +716,10 @@ QR+u0AypRPmzHnOPAAAAEXJvb3RAMTQwOTExNTQ5NDBkAQ== let path = dir.path().join("known_hosts"); { let mut f = File::create(&path).unwrap(); - f.write(b"[localhost]:13265 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJdD7y3aLq454yWBdwLWbieU1ebz9/cu7/QEXn9OIeZJ\n#pijul.org,37.120.161.53 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G2sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X\npijul.org,37.120.161.53 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G1sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X\n").unwrap(); + f.write_all(b"[localhost]:13265 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJdD7y3aLq454yWBdwLWbieU1ebz9/cu7/QEXn9OIeZJ\n").unwrap(); + f.write_all(b"#pijul.org,37.120.161.53 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G2sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X\n").unwrap(); + f.write_all(b"pijul.org,37.120.161.53 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G1sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X\n").unwrap(); + f.write_all(b"|1|O33ESRMWPVkMYIwJ1Uw+n877jTo=|nuuC5vEqXlEZ/8BXQR7m619W6Ak= ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILIG2T/B0l0gaqj3puu510tu9N1OkQ4znY3LYuEm5zCF\n").unwrap(); } // Valid key, non-standard port. @@ -538,6 +731,15 @@ QR+u0AypRPmzHnOPAAAAEXJvb3RAMTQwOTExNTQ5NDBkAQ== .unwrap(); assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); + // Valid key, hashed. + let host = "example.com"; + let port = 22; + let hostkey = parse_public_key_base64( + "AAAAC3NzaC1lZDI1NTE5AAAAILIG2T/B0l0gaqj3puu510tu9N1OkQ4znY3LYuEm5zCF", + ) + .unwrap(); + assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); + // Valid key, several hosts, port 22 let host = "pijul.org"; let port = 22; @@ -558,7 +760,45 @@ QR+u0AypRPmzHnOPAAAAEXJvb3RAMTQwOTExNTQ5NDBkAQ== } #[test] - #[cfg(feature = "openssl")] + fn test_parse_p256_public_key() { + env_logger::try_init().unwrap_or(()); + let key = "AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBMxBTpMIGvo7CnordO7wP0QQRqpBwUjOLl4eMhfucfE1sjTYyK5wmTl1UqoSDS1PtRVTBdl+0+9pquFb46U7fwg="; + + assert!(matches!( + parse_public_key_base64(key), + Ok(key::PublicKey::EC { + key: ec::PublicKey::P256(_), + }) + )); + } + + #[test] + fn test_parse_p384_public_key() { + env_logger::try_init().unwrap_or(()); + let key = "AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBBVFgxJxpCaAALZG/S5BHT8/IUQ5mfuKaj7Av9g7Jw59fBEGHfPBz1wFtHGYw5bdLmfVZTIDfogDid5zqJeAKr1AcD06DKTXDzd2EpUjqeLfQ5b3erHuX758fgu/pSDGRA=="; + + assert!(matches!( + parse_public_key_base64(key), + Ok(key::PublicKey::EC { + key: ec::PublicKey::P384(_), + }) + )); + } + + #[test] + fn test_parse_p521_public_key() { + env_logger::try_init().unwrap_or(()); + let key = "AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBAAQepXEpOrzlX22r4E5zEHjhHWeZUe//zaevTanOWRBnnaCGWJFGCdjeAbNOuAmLtXc+HZdJTCZGREeSLSrpJa71QDCgZl0N7DkDUanCpHZJe/DCK6qwtHYbEMn28iLMlGCOrCIa060EyJHbp1xcJx4I1SKj/f/fm3DhhID/do6zyf8Cg=="; + + assert!(matches!( + parse_public_key_base64(key), + Ok(key::PublicKey::EC { + key: ec::PublicKey::P521(_), + }) + )); + } + + #[test] fn test_srhb() { env_logger::try_init().unwrap_or(()); let key = "AAAAB3NzaC1yc2EAAAADAQABAAACAQC0Xtz3tSNgbUQAXem4d+d6hMx7S8Nwm/DOO2AWyWCru+n/+jQ7wz2b5+3oG2+7GbWZNGj8HCc6wJSA3jUsgv1N6PImIWclD14qvoqY3Dea1J0CJgXnnM1xKzBz9C6pDHGvdtySg+yzEO41Xt4u7HFn4Zx5SGuI2NBsF5mtMLZXSi33jCIWVIkrJVd7sZaY8jiqeVZBB/UvkLPWewGVuSXZHT84pNw4+S0Rh6P6zdNutK+JbeuO+5Bav4h9iw4t2sdRkEiWg/AdMoSKmo97Gigq2mKdW12ivnXxz3VfxrCgYJj9WwaUUWSfnAju5SiNly0cTEAN4dJ7yB0mfLKope1kRhPsNaOuUmMUqlu/hBDM/luOCzNjyVJ+0LLB7SV5vOiV7xkVd4KbEGKou8eeCR3yjFazUe/D1pjYPssPL8cJhTSuMc+/UC9zD8yeEZhB9V+vW4NMUR+lh5+XeOzenl65lWYd/nBZXLBbpUMf1AOfbz65xluwCxr2D2lj46iApSIpvE63i3LzFkbGl9GdUiuZJLMFJzOWdhGGc97cB5OVyf8umZLqMHjaImxHEHrnPh1MOVpv87HYJtSBEsN4/omINCMZrk++CRYAIRKRpPKFWV7NQHcvw3m7XLR3KaTYe+0/MINIZwGdou9fLUU3zSd521vDjA/weasH0CyDHq7sZw=="; @@ -567,7 +807,6 @@ QR+u0AypRPmzHnOPAAAAEXJvb3RAMTQwOTExNTQ5NDBkAQ== } #[test] - #[cfg(feature = "openssl")] fn test_nikao() { env_logger::try_init().unwrap_or(()); let key = "-----BEGIN RSA PRIVATE KEY----- @@ -601,7 +840,238 @@ QaChXiDsryJZwsRnruvMRX9nedtqHrgnIsJLTXjppIhGhq5Kg4RQfOU= decode_secret_key(key, None).unwrap(); } - #[cfg(feature = "openssl")] + #[test] + fn test_decode_pkcs8_rsa_secret_key() { + // Generated using: ssh-keygen -t rsa -b 1024 -m pkcs8 -f $file + let key = "-----BEGIN PRIVATE KEY----- +MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBAKMR20Sc+tU5dS7C +YzIuWnzobqTrIi9JExPTq4GEj01HJ1RJoOoezuiZuIg3iSRRETjXR+pKSzlLEh4v +9VmaDNQMT08EHYCc7NEKXb3c3k/4RNSHtvxKAsyK2ucrvaJGO5GDP7W+yQXpt8Os +KlD8G5LHZJMrZ5m1a+sHYdGzphRXAgMBAAECgYBSG8CjaMOoL3lApSJbdxmbAVIM ++lRJKOtRNWiLG5soVyaHe1dp6z9VwWk4NXZ5cdRRIZ0VbHk6DQG/b3iDuFyybqu3 +M7B40+4N7DCJfoWxALCEDSPQQ/Rp7rQ15YdNahZqe+/c8BHVxHdUZNXvMY8QX8jI +ZmoH8e17tRFKB0SZqQJBANjtPcEo5goaaZlly5VWs8SdNrG/ZM4vKQgpwQmtiNJg +TznqMPBcc8Qk43a6BlPDdn8CrBBjeYRF7qGh0cVdca0CQQDAcTQzF+HfWImqttB0 +dCo+jOqKOovXTTJcp4JUMzgvnMHwQZUJRNQxxqkIrmh/gUwWacSK/yxpLgKlXzBz +msaTAkEAk7VPVISVxxFfEE2pR0HnXJy0TmoFqQOhy+YqhH1+acmciNH3iuNZDJkV +rZVTk5vHxwo5wVsKtk+sArEeFmbfbQJAMbUL5qakkSwtYwsVjP70anO7oTi+Jj6q +Y4RhBZ61RJcZARXviRVeOf02bCeglk6veJqZSc3fist3o3+S5El2QQJBAJjjKA9q +bjFFWPDS9kyrpZL1SOjRIM/Mb0K1hCQd/kfbRTCamqvfuPDQ2A9N40bfBiQFQPph +csKph4+a9f37jyE= +-----END PRIVATE KEY----- +"; + assert!(matches!( + decode_secret_key(key, None), + Ok(key::KeyPair::RSA { .. }) + )); + test_decode_encode_symmetry(key); + } + + #[test] + fn test_decode_pkcs8_p256_secret_key() { + // Generated using: ssh-keygen -t ecdsa -b 256 -m pkcs8 -f $file + let key = "-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgE0C7/pyJDcZTAgWo +ydj6EE8QkZ91jtGoGmdYAVd7LaqhRANCAATWkGOof7R/PAUuOr2+ZPUgB8rGVvgr +qa92U3p4fkJToKXku5eq/32OBj23YMtz76jO3yfMbtG3l1JWLowPA8tV +-----END PRIVATE KEY----- +"; + assert!(matches!( + decode_secret_key(key, None), + Ok(key::KeyPair::EC { + key: ec::PrivateKey::P256(_) + }) + )); + test_decode_encode_symmetry(key); + } + + #[test] + fn test_decode_pkcs8_p384_secret_key() { + // Generated using: ssh-keygen -t ecdsa -b 384 -m pkcs8 -f $file + let key = "-----BEGIN PRIVATE KEY----- +MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDCaqAL30kg+T5BUOYG9 +MrzeDXiUwy9LM8qJGNXiMYou0pVjFZPZT3jAsrUQo47PLQ6hZANiAARuEHbXJBYK +9uyJj4PjT56OHjT2GqMa6i+FTG9vdLtu4OLUkXku+kOuFNjKvEI1JYBrJTpw9kSZ +CI3WfCsQvVjoC7m8qRyxuvR3Rv8gGXR1coQciIoCurLnn9zOFvXCS2Y= +-----END PRIVATE KEY----- +"; + assert!(matches!( + decode_secret_key(key, None), + Ok(key::KeyPair::EC { + key: ec::PrivateKey::P384(_) + }) + )); + test_decode_encode_symmetry(key); + } + + #[test] + fn test_decode_pkcs8_p521_secret_key() { + // Generated using: ssh-keygen -t ecdsa -b 521 -m pkcs8 -f $file + let key = "-----BEGIN PRIVATE KEY----- +MIHuAgEAMBAGByqGSM49AgEGBSuBBAAjBIHWMIHTAgEBBEIB1As9UBUsCiMK7Rzs +EoMgqDM/TK7y7+HgCWzw5UujXvSXCzYCeBgfJszn7dVoJE9G/1ejmpnVTnypdKEu +iIvd4LyhgYkDgYYABAADBCrg7hkomJbCsPMuMcq68ulmo/6Tv8BDS13F8T14v5RN +/0iT/+nwp6CnbBFewMI2TOh/UZNyPpQ8wOFNn9zBmAFCMzkQibnSWK0hrRstY5LT +iaOYDwInbFDsHu8j3TGs29KxyVXMexeV6ROQyXzjVC/quT1R5cOQ7EadE4HvaWhT +Ow== +-----END PRIVATE KEY----- +"; + assert!(matches!( + decode_secret_key(key, None), + Ok(key::KeyPair::EC { + key: ec::PrivateKey::P521(_) + }) + )); + test_decode_encode_symmetry(key); + } + + #[test] + #[cfg(feature = "legacy-ed25519-pkcs8-parser")] + fn test_decode_pkcs8_ed25519_generated_by_russh_0_43() -> Result<(), crate::Error> { + // Generated by russh 0.43 + let key = "-----BEGIN PRIVATE KEY----- +MHMCAQEwBQYDK2VwBEIEQBHw4cXPpGgA+KdvPF5gxrzML+oa3yQk0JzIbWvmqM5H30RyBF8GrOWz +p77UAd3O4PgYzzFcUc79g8yKtbKhzJGhIwMhAN9EcgRfBqzls6e+1AHdzuD4GM8xXFHO/YPMirWy +ocyR + +-----END PRIVATE KEY----- +"; + + assert!(matches!( + decode_secret_key(key, None)?, + key::KeyPair::Ed25519(_) + )); + + let key::KeyPair::Ed25519(inner) = decode_secret_key(key, None)? else { + panic!(); + }; + + assert_eq!( + &inner.to_bytes(), + &[ + 17, 240, 225, 197, 207, 164, 104, 0, 248, 167, 111, 60, 94, 96, 198, 188, 204, 47, + 234, 26, 223, 36, 36, 208, 156, 200, 109, 107, 230, 168, 206, 71 + ] + ); + + Ok(()) + } + + fn test_decode_encode_symmetry(key: &str) { + let original_key_bytes = data_encoding::BASE64_MIME + .decode( + &key.lines() + .filter(|line| !line.starts_with("-----")) + .collect::>() + .join("") + .as_bytes(), + ) + .unwrap(); + let decoded_key = decode_secret_key(key, None).unwrap(); + let encoded_key_bytes = pkcs8::encode_pkcs8(&decoded_key).unwrap(); + assert_eq!(original_key_bytes, encoded_key_bytes); + } + + fn ecdsa_sign_verify(key: &str, public: &str) { + let key = decode_secret_key(key, None).unwrap(); + let buf = b"blabla"; + let sig = key.sign_detached(buf).unwrap(); + // Verify using the provided public key. + { + let public = parse_public_key_base64(public).unwrap(); + assert!(public.verify_detached(buf, sig.as_ref())); + } + // Verify using public key derived from the secret key. + { + let public = key.clone_public_key().unwrap(); + assert!(public.verify_detached(buf, sig.as_ref())); + } + // Sanity check that it uses a different random number. + { + let sig2 = key.sign_detached(buf).unwrap(); + match (sig, sig2) { + ( + key::Signature::ECDSA { + algorithm, + signature, + }, + key::Signature::ECDSA { + algorithm: algorithm2, + signature: signature2, + }, + ) => { + assert_eq!(algorithm, algorithm2); + assert_ne!(signature, signature2); + } + _ => assert!(false), + } + } + // Verify (r, s) = (0, 0) is an invalid signature. (CVE-2022-21449) + { + use crate::encoding::Encoding; + let mut sig = Vec::new(); + sig.extend_ssh_string(&[0]); + sig.extend_ssh_string(&[0]); + let public = key.clone_public_key().unwrap(); + assert_eq!(false, public.verify_detached(buf, &sig)); + } + } + + #[test] + fn test_ecdsa_sha2_nistp256_sign_verify() { + env_logger::try_init().unwrap_or(()); + let key = "-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAaAAAABNlY2RzYS +1zaGEyLW5pc3RwMjU2AAAACG5pc3RwMjU2AAAAQQRQh23nB1wSlbAwhX3hrbNa35Z6vuY1 +CnEhAjk4FSWR1/tcna7RKCMXdYEiPs5rHr+mMoJxeQxmCd+ny8uIBrg1AAAAqKgQe5KoEH +uSAAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBFCHbecHXBKVsDCF +feGts1rflnq+5jUKcSECOTgVJZHX+1ydrtEoIxd1gSI+zmsev6YygnF5DGYJ36fLy4gGuD +UAAAAgFOgyq4FDOtEe+vBy1O1dqMLjXrKmqcgPpOO3+9cbPM0AAAAKZWNkc2FAdGVzdAEC +AwQFBg== +-----END OPENSSH PRIVATE KEY----- +"; + let public = "AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBFCHbecHXBKVsDCFfeGts1rflnq+5jUKcSECOTgVJZHX+1ydrtEoIxd1gSI+zmsev6YygnF5DGYJ36fLy4gGuDU="; + ecdsa_sign_verify(key, public); + } + + #[test] + fn test_ecdsa_sha2_nistp384_sign_verify() { + env_logger::try_init().unwrap_or(()); + let key = "-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAiAAAABNlY2RzYS +1zaGEyLW5pc3RwMzg0AAAACG5pc3RwMzg0AAAAYQRC1ed+3MGnknPWE6rdbw9p5f91AJSC +a469EDg5+EkDdVEEN1dWakAtI+gLRlpMotYD0Cso1Nx2MU9nW5fWLBmLtOWU6C1SX6INXB +527U0Ex5AYetNPBIhdTWB1UhbVkxgAAADYiT5XRYk+V0UAAAATZWNkc2Etc2hhMi1uaXN0 +cDM4NAAAAAhuaXN0cDM4NAAAAGEEQtXnftzBp5Jz1hOq3W8PaeX/dQCUgmuOvRA4OfhJA3 +VRBDdXVmpALSPoC0ZaTKLWA9ArKNTcdjFPZ1uX1iwZi7TllOgtUl+iDVwedu1NBMeQGHrT +TwSIXU1gdVIW1ZMYAAAAMH13rmHaaOv7SG4v/e3AV6yY49DzZD8YTzHRS62KDUPB/6t774 +PCeBxYsjjIg5q1FwAAAAplY2RzYUB0ZXN0AQIDBAUG +-----END OPENSSH PRIVATE KEY----- +"; + let public = "AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBELV537cwaeSc9YTqt1vD2nl/3UAlIJrjr0QODn4SQN1UQQ3V1ZqQC0j6AtGWkyi1gPQKyjU3HYxT2dbl9YsGYu05ZToLVJfog1cHnbtTQTHkBh6008EiF1NYHVSFtWTGA=="; + ecdsa_sign_verify(key, public); + } + + #[test] + fn test_ecdsa_sha2_nistp521_sign_verify() { + env_logger::try_init().unwrap_or(()); + let key = "-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAArAAAABNlY2RzYS +1zaGEyLW5pc3RwNTIxAAAACG5pc3RwNTIxAAAAhQQBVr19z0rsH1q3nly7RMJBfcHQER5H +oyqEAfX6NnGsa6atBcILGTKYNk/wqf58WabI1XY0ZGsJrx9twIbD6Wu0IcMAlS4MEYNjk7 +/J0FWEfYVKRIRRSK8bzT2uiDxRwmH1ZkQSEE/ghur46O4pA4H++w699LU3alWtDx+bJfx7 +zu4XjHwAAAEQqaEnO6mhJzsAAAATZWNkc2Etc2hhMi1uaXN0cDUyMQAAAAhuaXN0cDUyMQ +AAAIUEAVa9fc9K7B9at55cu0TCQX3B0BEeR6MqhAH1+jZxrGumrQXCCxkymDZP8Kn+fFmm +yNV2NGRrCa8fbcCGw+lrtCHDAJUuDBGDY5O/ydBVhH2FSkSEUUivG809rog8UcJh9WZEEh +BP4Ibq+OjuKQOB/vsOvfS1N2pVrQ8fmyX8e87uF4x8AAAAQgE10hd4g3skdWl4djRX4kE3 +ZgmnWhuwhyxErC5UkMHiEvTOZllxBvefs7XeJqL11pqQIHY4Gb5OQGiCNHiRRjg0egAAAA +1yb2JlcnRAYmJzZGV2AQIDBAU= +-----END OPENSSH PRIVATE KEY----- +"; + let public = "AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBAFWvX3PSuwfWreeXLtEwkF9wdARHkejKoQB9fo2caxrpq0FwgsZMpg2T/Cp/nxZpsjVdjRkawmvH23AhsPpa7QhwwCVLgwRg2OTv8nQVYR9hUpEhFFIrxvNPa6IPFHCYfVmRBIQT+CG6vjo7ikDgf77Dr30tTdqVa0PH5sl/HvO7heMfA=="; + ecdsa_sign_verify(key, public); + } + pub const PKCS8_RSA: &str = "-----BEGIN RSA PRIVATE KEY----- MIIEpAIBAAKCAQEAwBGetHjW+3bDQpVktdemnk7JXgu1NBWUM+ysifYLDBvJ9ttX GNZSyQKA4v/dNr0FhAJ8I9BuOTjYCy1YfKylhl5D/DiSSXFPsQzERMmGgAlYvU2U @@ -632,7 +1102,6 @@ xV/JrzLAwPoKk3bkqys3bUmgo6DxVC/6RmMwPQ0rmpw78kOgEej90g== "; #[test] - #[cfg(feature = "openssl")] fn test_loewenheim() -> Result<(), Error> { env_logger::try_init().unwrap_or(()); let key = "-----BEGIN RSA PRIVATE KEY----- @@ -675,7 +1144,6 @@ KJaj7gc0n6gmKY6r0/Ddufy1JZ6eihBCSJ64RARBXeg2rZpyT+xxhMEZLK5meOeR } #[test] - #[cfg(feature = "openssl")] fn test_o01eg() { env_logger::try_init().unwrap_or(()); @@ -713,14 +1181,12 @@ br8gXU8KyiY9sZVbmplRPF+ar462zcI2kt0a18mr0vbrdqp2eMjb37QDbVBJ+rPE decode_secret_key(key, Some("12345")).unwrap(); } #[test] - #[cfg(feature = "openssl")] fn test_pkcs8() { env_logger::try_init().unwrap_or(()); println!("test"); decode_secret_key(PKCS8_RSA, Some("blabla")).unwrap(); } - #[cfg(feature = "openssl")] const PKCS8_ENCRYPTED: &str = "-----BEGIN ENCRYPTED PRIVATE KEY----- MIIFLTBXBgkqhkiG9w0BBQ0wSjApBgkqhkiG9w0BBQwwHAQITo1O0b8YrS0CAggA MAwGCCqGSIb3DQIJBQAwHQYJYIZIAWUDBAEqBBBtLH4T1KOfo1GGr7salhR8BIIE @@ -753,7 +1219,6 @@ Cog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux -----END ENCRYPTED PRIVATE KEY-----"; #[test] - #[cfg(feature = "openssl")] fn test_gpg() { env_logger::try_init().unwrap_or(()); let algo = [115, 115, 104, 45, 114, 115, 97]; @@ -786,7 +1251,6 @@ Cog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux } #[test] - #[cfg(feature = "openssl")] fn test_pkcs8_encrypted() { env_logger::try_init().unwrap_or(()); println!("test"); @@ -794,71 +1258,73 @@ Cog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux } #[cfg(unix)] - fn test_client_agent(key: key::KeyPair) { + async fn test_client_agent(key: key::KeyPair) -> Result<(), Box> { env_logger::try_init().unwrap_or(()); - use std::process::{Command, Stdio}; - let dir = tempdir::TempDir::new("russh").unwrap(); + use std::process::Stdio; + + let dir = tempdir::TempDir::new("russh")?; let agent_path = dir.path().join("agent"); - let mut agent = Command::new("ssh-agent") + let mut agent = tokio::process::Command::new("ssh-agent") .arg("-a") .arg(&agent_path) .arg("-D") .stdout(Stdio::null()) .stderr(Stdio::null()) - .spawn() - .expect("failed to execute process"); - std::thread::sleep(std::time::Duration::from_millis(10)); - let rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(async move { - let public = key.clone_public_key()?; - let stream = tokio::net::UnixStream::connect(&agent_path).await?; - let mut client = agent::client::AgentClient::connect(stream); - client.add_identity(&key, &[]).await?; - client.request_identities().await?; - let buf = russh_cryptovec::CryptoVec::from_slice(b"blabla"); - let len = buf.len(); - let (_, buf) = client.sign_request(&public, buf).await; - let buf = buf?; - let (a, b) = buf.split_at(len); - match key { - key::KeyPair::Ed25519 { .. } => { - let sig = &b[b.len() - 64..]; - assert!(public.verify_detached(a, sig)); - } - #[cfg(feature = "openssl")] - _ => {} + .spawn()?; + + // Wait for the socket to be created + while agent_path.canonicalize().is_err() { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + + let public = key.clone_public_key()?; + let stream = tokio::net::UnixStream::connect(&agent_path).await?; + let mut client = agent::client::AgentClient::connect(stream); + client.add_identity(&key, &[]).await?; + client.request_identities().await?; + let buf = russh_cryptovec::CryptoVec::from_slice(b"blabla"); + let len = buf.len(); + let (_, buf) = client.sign_request(&public, buf).await; + let buf = buf?; + let (a, b) = buf.split_at(len); + match key { + key::KeyPair::Ed25519 { .. } => { + let sig = &b[b.len() - 64..]; + assert!(public.verify_detached(a, sig)); } - Ok::<(), Error>(()) - }) - .unwrap(); - agent.kill().unwrap(); - agent.wait().unwrap(); + key::KeyPair::EC { .. } => {} + _ => {} + } + + agent.kill().await?; + agent.wait().await?; + + Ok(()) } - #[test] + #[tokio::test] #[cfg(unix)] - fn test_client_agent_ed25519() { + async fn test_client_agent_ed25519() { let key = decode_secret_key(ED25519_KEY, Some("blabla")).unwrap(); - test_client_agent(key) + test_client_agent(key).await.expect("ssh-agent test failed") } - #[test] - #[cfg(feature = "openssl")] - fn test_client_agent_rsa() { + #[tokio::test] + #[cfg(unix)] + async fn test_client_agent_rsa() { let key = decode_secret_key(PKCS8_ENCRYPTED, Some("blabla")).unwrap(); - test_client_agent(key) + test_client_agent(key).await.expect("ssh-agent test failed") } - #[test] - #[cfg(feature = "openssl")] - fn test_client_agent_openssh_rsa() { + #[tokio::test] + #[cfg(unix)] + async fn test_client_agent_openssh_rsa() { let key = decode_secret_key(RSA_KEY, None).unwrap(); - test_client_agent(key) + test_client_agent(key).await.expect("ssh-agent test failed") } #[test] #[cfg(unix)] - #[cfg(feature = "openssl")] fn test_agent() { env_logger::try_init().unwrap_or(()); let dir = tempdir::TempDir::new("russh").unwrap(); diff --git a/russh-keys/src/protocol.rs b/russh-keys/src/protocol.rs new file mode 100644 index 00000000..ff9a2933 --- /dev/null +++ b/russh-keys/src/protocol.rs @@ -0,0 +1,87 @@ +use std::borrow::Cow; + +use crate::encoding::{Encoding, Position, SshRead, SshWrite}; +use crate::key::zeroize_cow; + +type Result = std::result::Result; + +/// SSH RSA public key. +pub struct RsaPublicKey<'a> { + /// `e`: RSA public exponent. + pub public_exponent: Cow<'a, [u8]>, + /// `n`: RSA modulus. + pub modulus: Cow<'a, [u8]>, +} + +impl<'a> SshRead<'a> for RsaPublicKey<'a> { + fn read_ssh(pos: &mut Position<'a>) -> Result { + Ok(Self { + public_exponent: Cow::Borrowed(pos.read_mpint()?), + modulus: Cow::Borrowed(pos.read_mpint()?), + }) + } +} + +impl SshWrite for RsaPublicKey<'_> { + fn write_ssh(&self, encoder: &mut E) { + encoder.extend_ssh_mpint(&self.public_exponent); + encoder.extend_ssh_mpint(&self.modulus); + } +} + +/// SSH RSA private key. +pub struct RsaPrivateKey<'a> { + /// RSA public key. + pub public_key: RsaPublicKey<'a>, + /// `d`: RSA private exponent. + pub private_exponent: Cow<'a, [u8]>, + /// CRT coefficient: `(inverse of q) mod p`. + pub coefficient: Cow<'a, [u8]>, + /// `p`: first prime factor of `n`. + pub prime1: Cow<'a, [u8]>, + /// `q`: Second prime factor of `n`. + pub prime2: Cow<'a, [u8]>, + /// Comment. + pub comment: Cow<'a, [u8]>, +} + +impl<'a> SshRead<'a> for RsaPrivateKey<'a> { + fn read_ssh(pos: &mut Position<'a>) -> Result { + Ok(Self { + // Note the field order. + public_key: RsaPublicKey { + modulus: Cow::Borrowed(pos.read_mpint()?), + public_exponent: Cow::Borrowed(pos.read_mpint()?), + }, + private_exponent: Cow::Borrowed(pos.read_mpint()?), + coefficient: Cow::Borrowed(pos.read_mpint()?), + prime1: Cow::Borrowed(pos.read_mpint()?), + prime2: Cow::Borrowed(pos.read_mpint()?), + comment: Cow::Borrowed(pos.read_string()?), + }) + } +} + +impl SshWrite for RsaPrivateKey<'_> { + fn write_ssh(&self, encoder: &mut E) { + // Note the field order. + encoder.extend_ssh_mpint(&self.public_key.modulus); + encoder.extend_ssh_mpint(&self.public_key.public_exponent); + encoder.extend_ssh_mpint(&self.private_exponent); + encoder.extend_ssh_mpint(&self.coefficient); + encoder.extend_ssh_mpint(&self.prime1); + encoder.extend_ssh_mpint(&self.prime2); + encoder.extend_ssh_string(&self.comment); + } +} + +impl Drop for RsaPrivateKey<'_> { + fn drop(&mut self) { + // Private parts only. + zeroize_cow(&mut self.private_exponent); + zeroize_cow(&mut self.coefficient); + zeroize_cow(&mut self.prime1); + zeroize_cow(&mut self.prime2); + zeroize_cow(&mut self.comment); + } +} diff --git a/russh-keys/src/signature.rs b/russh-keys/src/signature.rs index 712139c2..7bf05c21 100644 --- a/russh-keys/src/signature.rs +++ b/russh-keys/src/signature.rs @@ -18,6 +18,14 @@ pub enum Signature { Ed25519(SignatureBytes), /// An RSA signature RSA { hash: SignatureHash, bytes: Vec }, + /// An ECDSA signature + ECDSA { + /// Algorithm name defined in RFC 5656 section 3.1.2, in the form of + /// `"ecdsa-sha2-[identifier]"`. + algorithm: &'static str, + /// Signature blob defined in RFC 5656 section 3.1.2. + signature: Vec, + }, } impl Signature { @@ -48,7 +56,19 @@ impl Signature { .write_u32::((t.len() + bytes.len() + 8) as u32) .unwrap(); bytes_.extend_ssh_string(t); - bytes_.extend_ssh_string(&bytes[..]); + bytes_.extend_ssh_string(bytes); + } + Signature::ECDSA { + algorithm, + signature, + } => { + let algorithm = algorithm.as_bytes(); + #[allow(clippy::unwrap_used)] // Vec<>.write_all can't fail + bytes_ + .write_u32::((algorithm.len() + signature.len() + 8) as u32) + .unwrap(); + bytes_.extend_ssh_string(algorithm); + bytes_.extend_ssh_string(signature); } } data_encoding::BASE64_NOPAD.encode(&bytes_[..]) @@ -80,6 +100,18 @@ impl Signature { hash: SignatureHash::SHA1, bytes: bytes.to_vec(), }), + crate::KEYTYPE_ECDSA_SHA2_NISTP256 => Ok(Signature::ECDSA { + algorithm: crate::ECDSA_SHA2_NISTP256, + signature: bytes.to_vec(), + }), + crate::KEYTYPE_ECDSA_SHA2_NISTP384 => Ok(Signature::ECDSA { + algorithm: crate::ECDSA_SHA2_NISTP384, + signature: bytes.to_vec(), + }), + crate::KEYTYPE_ECDSA_SHA2_NISTP521 => Ok(Signature::ECDSA { + algorithm: crate::ECDSA_SHA2_NISTP521, + signature: bytes.to_vec(), + }), _ => Err(Error::UnknownSignatureType { sig_type: std::str::from_utf8(typ).unwrap_or("").to_string(), }), @@ -92,6 +124,7 @@ impl AsRef<[u8]> for Signature { match *self { Signature::Ed25519(ref signature) => &signature.0, Signature::RSA { ref bytes, .. } => &bytes[..], + Signature::ECDSA { ref signature, .. } => &signature[..], } } } diff --git a/russh/Cargo.toml b/russh/Cargo.toml index 8f792f79..fecdf24d 100644 --- a/russh/Cargo.toml +++ b/russh/Cargo.toml @@ -3,48 +3,57 @@ authors = ["Pierre-Étienne Meunier "] description = "A client and server SSH library." documentation = "https://docs.rs/russh" edition = "2018" -homepage = "https://pijul.org/russh" +homepage = "https://github.com/warp-tech/russh" keywords = ["ssh"] license = "Apache-2.0" name = "russh" readme = "../README.md" repository = "https://github.com/warp-tech/russh" -version = "0.38.0-beta.1" -rust-version = "1.60" +version = "0.46.0-beta.4" +rust-version = "1.65" [features] default = ["flate2"] openssl = ["russh-keys/openssl", "dep:openssl"] vendored-openssl = ["openssl/vendored", "russh-keys/vendored-openssl"] +legacy-ed25519-pkcs8-parser = ["russh-keys/legacy-ed25519-pkcs8-parser"] [dependencies] -aes = "0.8" +aes = { workspace = true } aes-gcm = "0.10" -async-trait = "0.1" +cbc = { version = "0.1" } +async-trait = { workspace = true } bitflags = "2.0" -byteorder = "1.3" +byteorder = { workspace = true } chacha20 = "0.9" -curve25519-dalek = "4.0" -poly1305 = "0.8" ctr = "0.9" -digest = "0.10" +curve25519-dalek = "4.1.3" +digest = { workspace = true } +elliptic-curve = { version = "0.13", features = ["ecdh"] } flate2 = { version = "1.0", optional = true } -futures = "0.3" +futures = { workspace = true } generic-array = "0.14" -hmac = "0.12" -log = "0.4" -once_cell = "1.13" -openssl = { version = "0.10", optional = true } -rand = "0.8" -russh-cryptovec = { version = "0.7.0", path = "../cryptovec" } -russh-keys = { version = "0.37.1", path = "../russh-keys" } -sha1 = "0.10" -sha2 = "0.10" hex-literal = "0.4" +hmac = { workspace = true } +log = { workspace = true } num-bigint = { version = "0.4", features = ["rand"] } +once_cell = "1.13" +openssl = { workspace = true, optional = true } +p256 = { version = "0.13", features = ["ecdh"] } +p384 = { version = "0.13", features = ["ecdh"] } +p521 = { version = "0.13", features = ["ecdh"] } +poly1305 = "0.8" +rand = { workspace = true } +rand_core = { version = "0.6.4", features = ["getrandom"] } +russh-cryptovec = { version = "0.7.0", path = "../cryptovec" } +russh-keys = { version = "0.46.0-beta.2", path = "../russh-keys" } +sha1 = { workspace = true } +sha2 = { workspace = true } +ssh-encoding = { workspace = true } +ssh-key = { workspace = true } subtle = "2.4" -thiserror = "1.0" -tokio = { version = "1.17.0", features = [ +thiserror = { workspace = true } +tokio = { workspace = true, features = [ "io-util", "rt-multi-thread", "time", @@ -53,12 +62,14 @@ tokio = { version = "1.17.0", features = [ "macros", "process", ] } -tokio-util = "0.7" +des = "0.8.1" [dev-dependencies] anyhow = "1.0" env_logger = "0.10" +clap = { version = "3.2", features = ["derive"] } tokio = { version = "1.17.0", features = [ + "io-std", "io-util", "rt-multi-thread", "time", @@ -66,7 +77,12 @@ tokio = { version = "1.17.0", features = [ "sync", "macros", ] } -russh-sftp = "1.1" +russh-sftp = "2.0.0-beta.2" +rand = "0.8.5" +shell-escape = "0.1" +tokio-fd = "0.3" +termion = "2" +ratatui = "0.26.0" [package.metadata.docs.rs] features = ["openssl"] diff --git a/russh/examples/client.rs b/russh/examples/client.rs deleted file mode 100644 index 6948835f..00000000 --- a/russh/examples/client.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::sync::Arc; - -use anyhow::Context; -use async_trait::async_trait; -use russh::*; -use russh_keys::*; - -struct Client {} - -#[async_trait] -impl client::Handler for Client { - type Error = russh::Error; - - async fn check_server_key( - self, - server_public_key: &key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - println!("check_server_key: {:?}", server_public_key); - Ok((self, true)) - } -} - -#[tokio::main] -async fn main() { - env_logger::init(); - let config = russh::client::Config::default(); - let config = Arc::new(config); - let sh = Client {}; - - let mut agent = russh_keys::agent::client::AgentClient::connect_env() - .await - .unwrap(); - let mut identities = agent.request_identities().await.unwrap(); - let mut session = russh::client::connect(config, ("127.0.0.1", 2200), sh) - .await - .unwrap(); - let (_, auth_res) = session - .authenticate_future("pe", identities.pop().unwrap(), agent) - .await; - let auth_res = auth_res.unwrap(); - println!("=== auth: {}", auth_res); - let mut channel = session - .channel_open_direct_tcpip("localhost", 8000, "localhost", 3333) - .await - .unwrap(); - // let mut channel = session.channel_open_session().await.unwrap(); - println!("=== after open channel"); - let data = b"GET /les_affames.mkv HTTP/1.1\nUser-Agent: curl/7.68.0\nAccept: */*\nConnection: close\n\n"; - channel.data(&data[..]).await.unwrap(); - let mut f = std::fs::File::create("les_affames.mkv").unwrap(); - while let Some(msg) = channel.wait().await { - use std::io::Write; - match msg { - russh::ChannelMsg::Data { ref data } => { - f.write_all(data).unwrap(); - } - russh::ChannelMsg::Eof => { - f.flush().unwrap(); - break; - } - _ => {} - } - } - session - .disconnect(Disconnect::ByApplication, "", "English") - .await - .unwrap(); - let res = session.await.context("session await"); - println!("{:#?}", res); -} diff --git a/russh/examples/client_exec_interactive.rs b/russh/examples/client_exec_interactive.rs new file mode 100644 index 00000000..7a428301 --- /dev/null +++ b/russh/examples/client_exec_interactive.rs @@ -0,0 +1,227 @@ +/// +/// Run this example with: +/// cargo run --example client_exec_interactive -- -k +/// +use std::convert::TryFrom; +use std::env; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use clap::Parser; +use log::info; +use russh::keys::*; +use russh::*; +use termion::raw::IntoRawMode; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::ToSocketAddrs; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::builder() + .filter_level(log::LevelFilter::Info) + .init(); + + // CLI options are defined later in this file + let cli = Cli::parse(); + + info!("Connecting to {}:{}", cli.host, cli.port); + info!("Key path: {:?}", cli.private_key); + info!("OpenSSH Certificate path: {:?}", cli.openssh_certificate); + + // Session is a wrapper around a russh client, defined down below + let mut ssh = Session::connect( + cli.private_key, + cli.username.unwrap_or("root".to_string()), + cli.openssh_certificate, + (cli.host, cli.port), + ) + .await?; + info!("Connected"); + + let code = { + // We're using `termion` to put the terminal into raw mode, so that we can + // display the output of interactive applications correctly + let _raw_term = std::io::stdout().into_raw_mode()?; + ssh.call( + &cli.command + .into_iter() + .map(|x| shell_escape::escape(x.into())) // arguments are escaped manually since the SSH protocol doesn't support quoting + .collect::>() + .join(" "), + ) + .await? + }; + + println!("Exitcode: {:?}", code); + ssh.close().await?; + Ok(()) +} + +struct Client {} + +// More SSH event handlers +// can be defined in this trait +// In this example, we're only using Channel, so these aren't needed. +#[async_trait] +impl client::Handler for Client { + type Error = russh::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &key::PublicKey, + ) -> Result { + Ok(true) + } +} + +/// This struct is a convenience wrapper +/// around a russh client +/// that handles the input/output event loop +pub struct Session { + session: client::Handle, +} + +impl Session { + async fn connect, A: ToSocketAddrs>( + key_path: P, + user: impl Into, + openssh_cert_path: Option

, + addrs: A, + ) -> Result { + let key_pair = load_secret_key(key_path, None)?; + + // load ssh certificate + let mut openssh_cert = None; + if openssh_cert_path.is_some() { + openssh_cert = Some(load_openssh_certificate(openssh_cert_path.unwrap())?); + } + + let config = client::Config { + inactivity_timeout: Some(Duration::from_secs(5)), + ..<_>::default() + }; + + let config = Arc::new(config); + let sh = Client {}; + + let mut session = client::connect(config, addrs, sh).await?; + + // use publickey authentication, with or without certificate + if openssh_cert.is_none() { + let auth_res = session + .authenticate_publickey(user, Arc::new(key_pair)) + .await?; + + if !auth_res { + anyhow::bail!("Authentication (with publickey) failed"); + } + } else { + let auth_res = session + .authenticate_openssh_cert(user, Arc::new(key_pair), openssh_cert.unwrap()) + .await?; + + if !auth_res { + anyhow::bail!("Authentication (with publickey+cert) failed"); + } + } + + Ok(Self { session }) + } + + async fn call(&mut self, command: &str) -> Result { + let mut channel = self.session.channel_open_session().await?; + + // This example doesn't terminal resizing after the connection is established + let (w, h) = termion::terminal_size()?; + + // Request an interactive PTY from the server + channel + .request_pty( + false, + &env::var("TERM").unwrap_or("xterm".into()), + w as u32, + h as u32, + 0, + 0, + &[], // ideally you want to pass the actual terminal modes here + ) + .await?; + channel.exec(true, command).await?; + + let code; + let mut stdin = tokio_fd::AsyncFd::try_from(0)?; + let mut stdout = tokio_fd::AsyncFd::try_from(1)?; + let mut buf = vec![0; 1024]; + let mut stdin_closed = false; + + loop { + // Handle one of the possible events: + tokio::select! { + // There's terminal input available from the user + r = stdin.read(&mut buf), if !stdin_closed => { + match r { + Ok(0) => { + stdin_closed = true; + channel.eof().await?; + }, + // Send it to the server + Ok(n) => channel.data(&buf[..n]).await?, + Err(e) => return Err(e.into()), + }; + }, + // There's an event available on the session channel + Some(msg) = channel.wait() => { + match msg { + // Write data to the terminal + ChannelMsg::Data { ref data } => { + stdout.write_all(data).await?; + stdout.flush().await?; + } + // The command has returned an exit code + ChannelMsg::ExitStatus { exit_status } => { + code = exit_status; + if !stdin_closed { + channel.eof().await?; + } + break; + } + _ => {} + } + }, + } + } + Ok(code) + } + + async fn close(&mut self) -> Result<()> { + self.session + .disconnect(Disconnect::ByApplication, "", "English") + .await?; + Ok(()) + } +} + +#[derive(clap::Parser)] +#[clap(trailing_var_arg = true)] +pub struct Cli { + #[clap(index = 1)] + host: String, + + #[clap(long, short, default_value_t = 22)] + port: u16, + + #[clap(long, short)] + username: Option, + + #[clap(long, short = 'k')] + private_key: PathBuf, + + #[clap(long, short = 'o')] + openssh_certificate: Option, + + #[clap(multiple = true, index = 2, required = true)] + command: Vec, +} diff --git a/russh/examples/client_exec_simple.rs b/russh/examples/client_exec_simple.rs new file mode 100644 index 00000000..50787575 --- /dev/null +++ b/russh/examples/client_exec_simple.rs @@ -0,0 +1,158 @@ +/// +/// Run this example with: +/// cargo run --example client_exec_simple -- -k +/// +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use clap::Parser; +use log::info; +use russh::keys::*; +use russh::*; +use tokio::io::AsyncWriteExt; +use tokio::net::ToSocketAddrs; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::builder() + .filter_level(log::LevelFilter::Info) + .init(); + + // CLI options are defined later in this file + let cli = Cli::parse(); + + info!("Connecting to {}:{}", cli.host, cli.port); + info!("Key path: {:?}", cli.private_key); + + // Session is a wrapper around a russh client, defined down below + let mut ssh = Session::connect( + cli.private_key, + cli.username.unwrap_or("root".to_string()), + (cli.host, cli.port), + ) + .await?; + info!("Connected"); + + let code = ssh + .call( + &cli.command + .into_iter() + .map(|x| shell_escape::escape(x.into())) // arguments are escaped manually since the SSH protocol doesn't support quoting + .collect::>() + .join(" "), + ) + .await?; + + println!("Exitcode: {:?}", code); + ssh.close().await?; + Ok(()) +} + +struct Client {} + +// More SSH event handlers +// can be defined in this trait +// In this example, we're only using Channel, so these aren't needed. +#[async_trait] +impl client::Handler for Client { + type Error = russh::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &key::PublicKey, + ) -> Result { + Ok(true) + } +} + +/// This struct is a convenience wrapper +/// around a russh client +pub struct Session { + session: client::Handle, +} + +impl Session { + async fn connect, A: ToSocketAddrs>( + key_path: P, + user: impl Into, + addrs: A, + ) -> Result { + let key_pair = load_secret_key(key_path, None)?; + let config = client::Config { + inactivity_timeout: Some(Duration::from_secs(5)), + ..<_>::default() + }; + + let config = Arc::new(config); + let sh = Client {}; + + let mut session = client::connect(config, addrs, sh).await?; + let auth_res = session + .authenticate_publickey(user, Arc::new(key_pair)) + .await?; + + if !auth_res { + anyhow::bail!("Authentication failed"); + } + + Ok(Self { session }) + } + + async fn call(&mut self, command: &str) -> Result { + let mut channel = self.session.channel_open_session().await?; + channel.exec(true, command).await?; + + let mut code = None; + let mut stdout = tokio::io::stdout(); + + loop { + // There's an event available on the session channel + let Some(msg) = channel.wait().await else { + break; + }; + match msg { + // Write data to the terminal + ChannelMsg::Data { ref data } => { + stdout.write_all(data).await?; + stdout.flush().await?; + } + // The command has returned an exit code + ChannelMsg::ExitStatus { exit_status } => { + code = Some(exit_status); + // cannot leave the loop immediately, there might still be more data to receive + } + _ => {} + } + } + Ok(code.expect("program did not exit cleanly")) + } + + async fn close(&mut self) -> Result<()> { + self.session + .disconnect(Disconnect::ByApplication, "", "English") + .await?; + Ok(()) + } +} + +#[derive(clap::Parser)] +#[clap(trailing_var_arg = true)] +pub struct Cli { + #[clap(index = 1)] + host: String, + + #[clap(long, short, default_value_t = 22)] + port: u16, + + #[clap(long, short)] + username: Option, + + #[clap(long, short = 'k')] + private_key: PathBuf, + + #[clap(multiple = true, index = 2, required = true)] + command: Vec, +} diff --git a/russh/examples/echoserver.rs b/russh/examples/echoserver.rs index 64dbae92..96627af4 100644 --- a/russh/examples/echoserver.rs +++ b/russh/examples/echoserver.rs @@ -2,9 +2,9 @@ use std::collections::HashMap; use std::sync::Arc; use async_trait::async_trait; -use russh::server::{Msg, Session}; +use russh::keys::*; +use russh::server::{Msg, Server as _, Session}; use russh::*; -use russh_keys::*; use tokio::sync::Mutex; #[tokio::main] @@ -21,13 +21,11 @@ async fn main() { ..Default::default() }; let config = Arc::new(config); - let sh = Server { + let mut sh = Server { clients: Arc::new(Mutex::new(HashMap::new())), id: 0, }; - russh::server::run(config, ("0.0.0.0", 2222), sh) - .await - .unwrap(); + sh.run_on_address(config, ("0.0.0.0", 2222)).await.unwrap(); } #[derive(Clone)] @@ -61,54 +59,54 @@ impl server::Handler for Server { type Error = anyhow::Error; async fn channel_open_session( - self, + &mut self, channel: Channel, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { + session: &mut Session, + ) -> Result { { let mut clients = self.clients.lock().await; clients.insert((self.id, channel.id()), session.handle()); } - Ok((self, true, session)) + Ok(true) } async fn auth_publickey( - self, + &mut self, _: &str, _: &key::PublicKey, - ) -> Result<(Self, server::Auth), Self::Error> { - Ok((self, server::Auth::Accept)) + ) -> Result { + Ok(server::Auth::Accept) } async fn data( - mut self, + &mut self, channel: ChannelId, data: &[u8], - mut session: Session, - ) -> Result<(Self, Session), Self::Error> { + session: &mut Session, + ) -> Result<(), Self::Error> { let data = CryptoVec::from(format!("Got data: {}\r\n", String::from_utf8_lossy(data))); self.post(data.clone()).await; session.data(channel, data); - Ok((self, session)) + Ok(()) } async fn tcpip_forward( - self, + &mut self, address: &str, port: &mut u32, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { + session: &mut Session, + ) -> Result { let handle = session.handle(); let address = address.to_string(); let port = *port; tokio::spawn(async move { - let mut channel = handle + let channel = handle .channel_open_forwarded_tcpip(address, port, "1.2.3.4", 1234) .await .unwrap(); let _ = channel.data(&b"Hello from a forwarded port"[..]).await; let _ = channel.eof().await; }); - Ok((self, true, session)) + Ok(true) } } diff --git a/russh/examples/ratatui_app.rs b/russh/examples/ratatui_app.rs new file mode 100644 index 00000000..f42439cf --- /dev/null +++ b/russh/examples/ratatui_app.rs @@ -0,0 +1,212 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use ratatui::backend::CrosstermBackend; +use ratatui::layout::Rect; +use ratatui::style::{Color, Style}; +use ratatui::widgets::{Block, Borders, Clear, Paragraph}; +use ratatui::Terminal; +use russh::keys::key::PublicKey; +use russh::server::*; +use russh::{Channel, ChannelId}; +use tokio::sync::Mutex; + +type SshTerminal = Terminal>; + +struct App { + pub counter: usize, +} + +impl App { + pub fn new() -> App { + Self { counter: 0 } + } +} + +#[derive(Clone)] +struct TerminalHandle { + handle: Handle, + // The sink collects the data which is finally flushed to the handle. + sink: Vec, + channel_id: ChannelId, +} + +// The crossterm backend writes to the terminal handle. +impl std::io::Write for TerminalHandle { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.sink.extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + let handle = self.handle.clone(); + let channel_id = self.channel_id; + let data = self.sink.clone().into(); + futures::executor::block_on(async move { + let result = handle.data(channel_id, data).await; + if result.is_err() { + eprintln!("Failed to send data: {:?}", result); + } + }); + + self.sink.clear(); + Ok(()) + } +} + +#[derive(Clone)] +struct AppServer { + clients: Arc>>, + id: usize, +} + +impl AppServer { + pub fn new() -> Self { + Self { + clients: Arc::new(Mutex::new(HashMap::new())), + id: 0, + } + } + + pub async fn run(&mut self) -> Result<(), anyhow::Error> { + let clients = self.clients.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + + for (_, (terminal, app)) in clients.lock().await.iter_mut() { + app.counter += 1; + + terminal + .draw(|f| { + let size = f.size(); + f.render_widget(Clear, size); + let style = match app.counter % 3 { + 0 => Style::default().fg(Color::Red), + 1 => Style::default().fg(Color::Green), + _ => Style::default().fg(Color::Blue), + }; + let paragraph = Paragraph::new(format!("Counter: {}", app.counter)) + .alignment(ratatui::layout::Alignment::Center) + .style(style); + let block = Block::default() + .title("Press 'c' to reset the counter!") + .borders(Borders::ALL); + f.render_widget(paragraph.block(block), size); + }) + .unwrap(); + } + } + }); + + let config = Config { + inactivity_timeout: Some(std::time::Duration::from_secs(3600)), + auth_rejection_time: std::time::Duration::from_secs(3), + auth_rejection_time_initial: Some(std::time::Duration::from_secs(0)), + keys: vec![russh_keys::key::KeyPair::generate_ed25519().unwrap()], + ..Default::default() + }; + + self.run_on_address(Arc::new(config), ("0.0.0.0", 2222)) + .await?; + Ok(()) + } +} + +impl Server for AppServer { + type Handler = Self; + fn new_client(&mut self, _: Option) -> Self { + let s = self.clone(); + self.id += 1; + s + } +} + +#[async_trait] +impl Handler for AppServer { + type Error = anyhow::Error; + + async fn channel_open_session( + &mut self, + channel: Channel, + session: &mut Session, + ) -> Result { + { + let mut clients = self.clients.lock().await; + let terminal_handle = TerminalHandle { + handle: session.handle(), + sink: Vec::new(), + channel_id: channel.id(), + }; + + let backend = CrosstermBackend::new(terminal_handle.clone()); + let terminal = Terminal::new(backend)?; + let app = App::new(); + + clients.insert(self.id, (terminal, app)); + } + + Ok(true) + } + + async fn auth_publickey(&mut self, _: &str, _: &PublicKey) -> Result { + Ok(Auth::Accept) + } + + async fn data( + &mut self, + channel: ChannelId, + data: &[u8], + session: &mut Session, + ) -> Result<(), Self::Error> { + match data { + // Pressing 'q' closes the connection. + b"q" => { + self.clients.lock().await.remove(&self.id); + session.close(channel); + } + // Pressing 'c' resets the counter for the app. + // Only the client with the id sees the counter reset. + b"c" => { + let mut clients = self.clients.lock().await; + let (_, app) = clients.get_mut(&self.id).unwrap(); + app.counter = 0; + } + _ => {} + } + + Ok(()) + } + + /// The client's window size has changed. + async fn window_change_request( + &mut self, + _: ChannelId, + col_width: u32, + row_height: u32, + _: u32, + _: u32, + _: &mut Session, + ) -> Result<(), Self::Error> { + { + let mut clients = self.clients.lock().await; + let (terminal, _) = clients.get_mut(&self.id).unwrap(); + let rect = Rect { + x: 0, + y: 0, + width: col_width as u16, + height: row_height as u16, + }; + terminal.resize(rect)?; + } + + Ok(()) + } +} + +#[tokio::main] +async fn main() { + let mut server = AppServer::new(); + server.run().await.expect("Failed running server"); +} diff --git a/russh/examples/ratatui_shared_app.rs b/russh/examples/ratatui_shared_app.rs new file mode 100644 index 00000000..174a6d6e --- /dev/null +++ b/russh/examples/ratatui_shared_app.rs @@ -0,0 +1,211 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use ratatui::backend::CrosstermBackend; +use ratatui::layout::Rect; +use ratatui::style::{Color, Style}; +use ratatui::widgets::{Block, Borders, Clear, Paragraph}; +use ratatui::Terminal; +use russh::keys::key::PublicKey; +use russh::server::*; +use russh::{Channel, ChannelId}; +use tokio::sync::Mutex; + +type SshTerminal = Terminal>; + +struct App { + pub counter: usize, +} + +impl App { + pub fn new() -> App { + Self { counter: 0 } + } +} + +#[derive(Clone)] +struct TerminalHandle { + handle: Handle, + // The sink collects the data which is finally flushed to the handle. + sink: Vec, + channel_id: ChannelId, +} + +// The crossterm backend writes to the terminal handle. +impl std::io::Write for TerminalHandle { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.sink.extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + let handle = self.handle.clone(); + let channel_id = self.channel_id; + let data = self.sink.clone().into(); + futures::executor::block_on(async move { + let result = handle.data(channel_id, data).await; + if result.is_err() { + eprintln!("Failed to send data: {:?}", result); + } + }); + + self.sink.clear(); + Ok(()) + } +} + +#[derive(Clone)] +struct AppServer { + clients: Arc>>, + id: usize, + app: Arc>, +} + +impl AppServer { + pub fn new() -> Self { + Self { + clients: Arc::new(Mutex::new(HashMap::new())), + id: 0, + app: Arc::new(Mutex::new(App::new())), + } + } + + pub async fn run(&mut self) -> Result<(), anyhow::Error> { + let app = self.app.clone(); + let clients = self.clients.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + app.lock().await.counter += 1; + let counter = app.lock().await.counter; + for (_, terminal) in clients.lock().await.iter_mut() { + terminal + .draw(|f| { + let size = f.size(); + f.render_widget(Clear, size); + let style = match counter % 3 { + 0 => Style::default().fg(Color::Red), + 1 => Style::default().fg(Color::Green), + _ => Style::default().fg(Color::Blue), + }; + let paragraph = Paragraph::new(format!("Counter: {counter}")) + .alignment(ratatui::layout::Alignment::Center) + .style(style); + let block = Block::default() + .title("Press 'c' to reset the counter!") + .borders(Borders::ALL); + f.render_widget(paragraph.block(block), size); + }) + .unwrap(); + } + } + }); + + let config = Config { + inactivity_timeout: Some(std::time::Duration::from_secs(3600)), + auth_rejection_time: std::time::Duration::from_secs(3), + auth_rejection_time_initial: Some(std::time::Duration::from_secs(0)), + keys: vec![russh_keys::key::KeyPair::generate_ed25519().unwrap()], + ..Default::default() + }; + + self.run_on_address(Arc::new(config), ("0.0.0.0", 2222)) + .await?; + Ok(()) + } +} + +impl Server for AppServer { + type Handler = Self; + fn new_client(&mut self, _: Option) -> Self { + let s = self.clone(); + self.id += 1; + s + } +} + +#[async_trait] +impl Handler for AppServer { + type Error = anyhow::Error; + + async fn channel_open_session( + &mut self, + channel: Channel, + session: &mut Session, + ) -> Result { + { + let mut clients = self.clients.lock().await; + let terminal_handle = TerminalHandle { + handle: session.handle(), + sink: Vec::new(), + channel_id: channel.id(), + }; + + let backend = CrosstermBackend::new(terminal_handle.clone()); + let terminal = Terminal::new(backend)?; + clients.insert(self.id, terminal); + } + + Ok(true) + } + + async fn auth_publickey(&mut self, _: &str, _: &PublicKey) -> Result { + Ok(Auth::Accept) + } + + async fn data( + &mut self, + channel: ChannelId, + data: &[u8], + session: &mut Session, + ) -> Result<(), Self::Error> { + let app = self.app.clone(); + match data { + // Pressing 'q' closes the connection. + b"q" => { + self.clients.lock().await.remove(&self.id); + session.close(channel); + } + // Pressing 'c' resets the counter for the app. + // Every client sees the counter reset. + b"c" => { + app.lock().await.counter = 0; + } + _ => {} + } + + Ok(()) + } + + /// The client's pseudo-terminal window size has changed. + async fn window_change_request( + &mut self, + _: ChannelId, + col_width: u32, + row_height: u32, + _: u32, + _: u32, + _: &mut Session, + ) -> Result<(), Self::Error> { + let mut terminal = { + let clients = self.clients.lock().await; + clients.get(&self.id).unwrap().clone() + }; + let rect = Rect { + x: 0, + y: 0, + width: col_width as u16, + height: row_height as u16, + }; + terminal.resize(rect)?; + + Ok(()) + } +} + +#[tokio::main] +async fn main() { + let mut server = AppServer::new(); + server.run().await.expect("Failed running server"); +} diff --git a/russh/examples/remote_shell_call.rs b/russh/examples/remote_shell_call.rs deleted file mode 100644 index b77750b4..00000000 --- a/russh/examples/remote_shell_call.rs +++ /dev/null @@ -1,117 +0,0 @@ -use std::io::Write; -use std::path::Path; -use std::sync::Arc; -use std::time::Duration; - -use anyhow::Result; -use async_trait::async_trait; -use log::info; -use russh::*; -use russh_keys::*; -use tokio::net::ToSocketAddrs; - -#[tokio::main] -async fn main() -> Result<()> { - env_logger::builder() - .filter_level(log::LevelFilter::Debug) - .init(); - - let args: Vec = std::env::args().collect(); - let (host, key) = match args.get(1..3) { - Some(args) => (&args[0], &args[1]), - None => { - eprintln!("Usage: {} ", args[0]); - std::process::exit(1); - } - }; - - info!("Connecting to {host}"); - info!("Key path: {key}"); - - let mut ssh = Session::connect(key, "root", host).await?; - let r = ssh.call("whoami").await?; - assert!(r.success()); - println!("Result: {}", r.output()); - ssh.close().await?; - Ok(()) -} - -struct Client {} - -#[async_trait] -impl client::Handler for Client { - type Error = russh::Error; - - async fn check_server_key( - self, - _server_public_key: &key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - Ok((self, true)) - } -} - -pub struct Session { - session: client::Handle, -} - -impl Session { - async fn connect, A: ToSocketAddrs>( - key_path: P, - user: impl Into, - addrs: A, - ) -> Result { - let key_pair = load_secret_key(key_path, None)?; - let config = client::Config { - inactivity_timeout: Some(Duration::from_secs(5)), - ..<_>::default() - }; - let config = Arc::new(config); - let sh = Client {}; - let mut session = client::connect(config, addrs, sh).await?; - let _auth_res = session - .authenticate_publickey(user, Arc::new(key_pair)) - .await?; - Ok(Self { session }) - } - - async fn call(&mut self, command: &str) -> Result { - let mut channel = self.session.channel_open_session().await?; - channel.exec(true, command).await?; - let mut output = Vec::new(); - let mut code = None; - while let Some(msg) = channel.wait().await { - match msg { - russh::ChannelMsg::Data { ref data } => { - output.write_all(data).unwrap(); - } - russh::ChannelMsg::ExitStatus { exit_status } => { - code = Some(exit_status); - } - _ => {} - } - } - Ok(CommandResult { output, code }) - } - - async fn close(&mut self) -> Result<()> { - self.session - .disconnect(Disconnect::ByApplication, "", "English") - .await?; - Ok(()) - } -} - -struct CommandResult { - output: Vec, - code: Option, -} - -impl CommandResult { - fn output(&self) -> String { - String::from_utf8_lossy(&self.output).into() - } - - fn success(&self) -> bool { - self.code == Some(0) - } -} diff --git a/russh/examples/sftp_client.rs b/russh/examples/sftp_client.rs new file mode 100644 index 00000000..4870d057 --- /dev/null +++ b/russh/examples/sftp_client.rs @@ -0,0 +1,103 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use log::{error, info, LevelFilter}; +use russh::keys::*; +use russh::*; +use russh_sftp::client::SftpSession; +use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; + +struct Client; + +#[async_trait] +impl client::Handler for Client { + type Error = anyhow::Error; + + async fn check_server_key( + &mut self, + server_public_key: &key::PublicKey, + ) -> Result { + info!("check_server_key: {:?}", server_public_key); + Ok(true) + } + + async fn data( + &mut self, + channel: ChannelId, + data: &[u8], + _session: &mut client::Session, + ) -> Result<(), Self::Error> { + info!("data on channel {:?}: {}", channel, data.len()); + Ok(()) + } +} + +#[tokio::main] +async fn main() { + env_logger::builder() + .filter_level(LevelFilter::Debug) + .init(); + + let config = russh::client::Config::default(); + let sh = Client {}; + let mut session = russh::client::connect(Arc::new(config), ("localhost", 22), sh) + .await + .unwrap(); + if session.authenticate_password("root", "pass").await.unwrap() { + let channel = session.channel_open_session().await.unwrap(); + channel.request_subsystem(true, "sftp").await.unwrap(); + let sftp = SftpSession::new(channel.into_stream()).await.unwrap(); + info!("current path: {:?}", sftp.canonicalize(".").await.unwrap()); + + // create dir and symlink + let path = "./some_kind_of_dir"; + let symlink = "./symlink"; + + sftp.create_dir(path).await.unwrap(); + sftp.symlink(path, symlink).await.unwrap(); + + info!("dir info: {:?}", sftp.metadata(path).await.unwrap()); + info!( + "symlink info: {:?}", + sftp.symlink_metadata(path).await.unwrap() + ); + + // scanning directory + for entry in sftp.read_dir(".").await.unwrap() { + info!("file in directory: {:?}", entry.file_name()); + } + + sftp.remove_file(symlink).await.unwrap(); + sftp.remove_dir(path).await.unwrap(); + + // interaction with i/o + let filename = "test_new.txt"; + let mut file = sftp.create(filename).await.unwrap(); + info!("metadata by handle: {:?}", file.metadata().await.unwrap()); + + file.write_all(b"magic text").await.unwrap(); + info!("flush: {:?}", file.flush().await); // or file.sync_all() + info!( + "current cursor position: {:?}", + file.stream_position().await + ); + + let mut str = String::new(); + + file.rewind().await.unwrap(); + file.read_to_string(&mut str).await.unwrap(); + file.rewind().await.unwrap(); + + info!( + "our magical contents: {}, after rewind: {:?}", + str, + file.stream_position().await + ); + + file.shutdown().await.unwrap(); + sftp.remove_file(filename).await.unwrap(); + + // should fail because handle was closed + error!("should fail: {:?}", file.read_u8().await); + } +} diff --git a/russh/examples/sftp_server.rs b/russh/examples/sftp_server.rs index 4ce70153..2558fbb9 100644 --- a/russh/examples/sftp_server.rs +++ b/russh/examples/sftp_server.rs @@ -1,12 +1,14 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + use async_trait::async_trait; use log::{error, info, LevelFilter}; -use russh::{ - server::{Auth, Msg, Session}, - Channel, ChannelId, -}; -use russh_keys::key::KeyPair; +use russh::keys::key::KeyPair; +use russh::server::{Auth, Msg, Server as _, Session}; +use russh::{Channel, ChannelId}; use russh_sftp::protocol::{File, FileAttributes, Handle, Name, Status, StatusCode, Version}; -use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration}; use tokio::sync::Mutex; #[derive(Clone)] @@ -43,38 +45,38 @@ impl SshSession { impl russh::server::Handler for SshSession { type Error = anyhow::Error; - async fn auth_password(self, user: &str, password: &str) -> Result<(Self, Auth), Self::Error> { + async fn auth_password(&mut self, user: &str, password: &str) -> Result { info!("credentials: {}, {}", user, password); - Ok((self, Auth::Accept)) + Ok(Auth::Accept) } async fn auth_publickey( - self, + &mut self, user: &str, public_key: &russh_keys::key::PublicKey, - ) -> Result<(Self, Auth), Self::Error> { + ) -> Result { info!("credentials: {}, {:?}", user, public_key); - Ok((self, Auth::Accept)) + Ok(Auth::Accept) } async fn channel_open_session( - mut self, + &mut self, channel: Channel, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { + _session: &mut Session, + ) -> Result { { let mut clients = self.clients.lock().await; clients.insert(channel.id(), channel); } - Ok((self, true, session)) + Ok(true) } async fn subsystem_request( - mut self, + &mut self, channel_id: ChannelId, name: &str, - mut session: Session, - ) -> Result<(Self, Session), Self::Error> { + session: &mut Session, + ) -> Result<(), Self::Error> { info!("subsystem: {}", name); if name == "sftp" { @@ -86,24 +88,16 @@ impl russh::server::Handler for SshSession { session.channel_failure(channel_id); } - Ok((self, session)) + Ok(()) } } +#[derive(Default)] struct SftpSession { version: Option, root_dir_read_done: bool, } -impl Default for SftpSession { - fn default() -> Self { - Self { - version: None, - root_dir_read_done: false, - } - } -} - #[async_trait] impl russh_sftp::server::Handler for SftpSession { type Error = StatusCode; @@ -151,10 +145,12 @@ impl russh_sftp::server::Handler for SftpSession { files: vec![ File { filename: "foo".to_string(), + longname: "".to_string(), attrs: FileAttributes::default(), }, File { filename: "bar".to_string(), + longname: "".to_string(), attrs: FileAttributes::default(), }, ], @@ -169,6 +165,7 @@ impl russh_sftp::server::Handler for SftpSession { id, files: vec![File { filename: "/".to_string(), + longname: "".to_string(), attrs: FileAttributes::default(), }], }) @@ -185,23 +182,22 @@ async fn main() { auth_rejection_time: Duration::from_secs(3), auth_rejection_time_initial: Some(Duration::from_secs(0)), keys: vec![KeyPair::generate_ed25519().unwrap()], - inactivity_timeout: Some(Duration::from_secs(3600)), ..Default::default() }; - let server = Server; - - russh::server::run( - Arc::new(config), - ( - "0.0.0.0", - std::env::var("PORT") - .unwrap_or("22".to_string()) - .parse() - .unwrap(), - ), - server, - ) - .await - .unwrap(); + let mut server = Server; + + server + .run_on_address( + Arc::new(config), + ( + "0.0.0.0", + std::env::var("PORT") + .unwrap_or("22".to_string()) + .parse() + .unwrap(), + ), + ) + .await + .unwrap(); } diff --git a/russh/examples/test.rs b/russh/examples/test.rs index 034ea7e2..6cf6853b 100644 --- a/russh/examples/test.rs +++ b/russh/examples/test.rs @@ -1,11 +1,11 @@ -use async_trait::async_trait; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use async_trait::async_trait; use log::debug; -use russh::server::{Auth, Msg, Session}; +use russh::keys::*; +use russh::server::{Auth, Msg, Server as _, Session}; use russh::*; -use russh_keys::*; -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -16,13 +16,13 @@ async fn main() -> anyhow::Result<()> { .keys .push(russh_keys::key::KeyPair::generate_ed25519().unwrap()); let config = Arc::new(config); - let sh = Server { + let mut sh = Server { clients: Arc::new(Mutex::new(HashMap::new())), id: 0, }; tokio::time::timeout( std::time::Duration::from_secs(60), - russh::server::run(config, ("0.0.0.0", 2222), sh), + sh.run_on_address(config, ("0.0.0.0", 2222)), ) .await .unwrap_or(Ok(()))?; @@ -51,42 +51,38 @@ impl server::Handler for Server { type Error = anyhow::Error; async fn channel_open_session( - self, + &mut self, channel: Channel, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { + _session: &mut Session, + ) -> Result { { debug!("channel open session"); let mut clients = self.clients.lock().unwrap(); clients.insert((self.id, channel.id()), channel); } - Ok((self, true, session)) + Ok(true) } /// The client requests a shell. #[allow(unused_variables)] async fn shell_request( - self, + &mut self, channel: ChannelId, - mut session: Session, - ) -> Result<(Self, Session), Self::Error> { + session: &mut Session, + ) -> Result<(), Self::Error> { session.request_success(); - Ok((self, session)) + Ok(()) } - async fn auth_publickey( - self, - _: &str, - _: &key::PublicKey, - ) -> Result<(Self, Auth), Self::Error> { - Ok((self, server::Auth::Accept)) + async fn auth_publickey(&mut self, _: &str, _: &key::PublicKey) -> Result { + Ok(server::Auth::Accept) } async fn data( - self, + &mut self, _channel: ChannelId, data: &[u8], - mut session: Session, - ) -> Result<(Self, Session), Self::Error> { + session: &mut Session, + ) -> Result<(), Self::Error> { debug!("data: {data:?}"); { let mut clients = self.clients.lock().unwrap(); @@ -94,6 +90,6 @@ impl server::Handler for Server { session.data(channel.id(), CryptoVec::from(data.to_vec())); } } - Ok((self, session)) + Ok(()) } } diff --git a/russh/src/auth.rs b/russh/src/auth.rs index e64f3a42..0f3c3da6 100644 --- a/russh/src/auth.rs +++ b/russh/src/auth.rs @@ -16,11 +16,13 @@ use std::sync::Arc; use bitflags::bitflags; -use russh_cryptovec::CryptoVec; -use russh_keys::{encoding, key}; +use ssh_key::Certificate; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; +use crate::keys::{encoding, key}; +use crate::CryptoVec; + bitflags! { /// Set of authentication methods, represented by bit flags. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -77,10 +79,22 @@ impl Signer #[derive(Debug)] pub enum Method { None, - Password { password: String }, - PublicKey { key: Arc }, - FuturePublicKey { key: key::PublicKey }, - KeyboardInteractive { submethods: String }, + Password { + password: String, + }, + PublicKey { + key: Arc, + }, + OpenSSHCertificate { + key: Arc, + cert: Certificate, + }, + FuturePublicKey { + key: key::PublicKey, + }, + KeyboardInteractive { + submethods: String, + }, // Hostbased, } diff --git a/russh/src/cert.rs b/russh/src/cert.rs new file mode 100644 index 00000000..9991e81b --- /dev/null +++ b/russh/src/cert.rs @@ -0,0 +1,62 @@ +use ssh_encoding::Encode; +use ssh_key::{Algorithm, Certificate, EcdsaCurve}; + +use crate::key::PubKey; +use crate::keys::encoding::Encoding; +use crate::negotiation::Named; +use crate::CryptoVec; + +/// OpenSSH certificate for DSA public key +const CERT_DSA: &str = "ssh-dss-cert-v01@openssh.com"; + +/// OpenSSH certificate for ECDSA (NIST P-256) public key +const CERT_ECDSA_SHA2_P256: &str = "ecdsa-sha2-nistp256-cert-v01@openssh.com"; + +/// OpenSSH certificate for ECDSA (NIST P-384) public key +const CERT_ECDSA_SHA2_P384: &str = "ecdsa-sha2-nistp384-cert-v01@openssh.com"; + +/// OpenSSH certificate for ECDSA (NIST P-521) public key +const CERT_ECDSA_SHA2_P521: &str = "ecdsa-sha2-nistp521-cert-v01@openssh.com"; + +/// OpenSSH certificate for Ed25519 public key +const CERT_ED25519: &str = "ssh-ed25519-cert-v01@openssh.com"; + +/// OpenSSH certificate with RSA public key +const CERT_RSA: &str = "ssh-rsa-cert-v01@openssh.com"; + +/// OpenSSH certificate for ECDSA (NIST P-256) U2F/FIDO security key +const CERT_SK_ECDSA_SHA2_P256: &str = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com"; + +/// OpenSSH certificate for Ed25519 U2F/FIDO security key +const CERT_SK_SSH_ED25519: &str = "sk-ssh-ed25519-cert-v01@openssh.com"; + +/// None +const NONE: &str = "none"; + +impl PubKey for Certificate { + fn push_to(&self, buffer: &mut CryptoVec) { + let mut cert_encoded = Vec::new(); + let _ = self.encode(&mut cert_encoded); + + buffer.extend_ssh_string(&cert_encoded); + } +} + +impl Named for Certificate { + fn name(&self) -> &'static str { + match self.algorithm() { + Algorithm::Dsa => CERT_DSA, + Algorithm::Ecdsa { curve } => match curve { + EcdsaCurve::NistP256 => CERT_ECDSA_SHA2_P256, + EcdsaCurve::NistP384 => CERT_ECDSA_SHA2_P384, + EcdsaCurve::NistP521 => CERT_ECDSA_SHA2_P521, + }, + Algorithm::Ed25519 => CERT_ED25519, + Algorithm::Rsa { .. } => CERT_RSA, + Algorithm::SkEcdsaSha2NistP256 => CERT_SK_ECDSA_SHA2_P256, + Algorithm::SkEd25519 => CERT_SK_SSH_ED25519, + Algorithm::Other(_) => NONE, + _ => NONE, + } + } +} diff --git a/russh/src/channel_stream/mod.rs b/russh/src/channel_stream/mod.rs deleted file mode 100644 index 80ee89a6..00000000 --- a/russh/src/channel_stream/mod.rs +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// Originally from microsoft/dev-tunnels - -mod read_buffer; - -use std::io; -use std::pin::Pin; -use std::task::Poll; - -use log::debug; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::sync::mpsc; - -use self::read_buffer::ReadBuffer; - -/// AsyncRead/AsyncWrite wrapper for SSH Channels -pub struct ChannelStream { - incoming: mpsc::UnboundedReceiver>, - outgoing: mpsc::UnboundedSender>, - - readbuf: ReadBuffer, - - is_write_fut_valid: bool, - write_fut: tokio_util::sync::ReusableBoxFuture<'static, Result<(), Vec>>, -} - -impl ChannelStream { - pub fn new() -> ( - Self, - mpsc::UnboundedReceiver>, - mpsc::UnboundedSender>, - ) { - let (w_tx, w_rx) = mpsc::unbounded_channel(); - let (r_tx, r_rx) = mpsc::unbounded_channel(); - ( - ChannelStream { - incoming: w_rx, - outgoing: r_tx, - readbuf: ReadBuffer::default(), - is_write_fut_valid: false, - write_fut: tokio_util::sync::ReusableBoxFuture::new(make_client_write_fut(None)), - }, - r_rx, - w_tx, - ) - } -} - -/// Makes a future that writes to the russh handle. This general approach was -/// taken from https://docs.rs/tokio-util/0.7.3/tokio_util/sync/struct.PollSender.html -/// This is just like make_server_write_fut, but for clients (they don't share a trait...) -async fn make_client_write_fut( - data: Option<(mpsc::UnboundedSender>, Vec)>, -) -> Result<(), Vec> { - match data { - Some((sender, data)) => sender.send(data).map_err(|e| e.0), - None => unreachable!("this future should not be pollable in this state"), - } -} - -impl AsyncWrite for ChannelStream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - if !self.is_write_fut_valid { - let outgoing = self.outgoing.clone(); - self.write_fut - .set(make_client_write_fut(Some((outgoing, buf.to_vec())))); - self.is_write_fut_valid = true; - } - - self.poll_flush(cx).map(|r| r.map(|_| buf.len())) - } - - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - if !self.is_write_fut_valid { - return Poll::Ready(Ok(())); - } - - match self.write_fut.poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(_)) => { - self.is_write_fut_valid = false; - Poll::Ready(Ok(())) - } - Poll::Ready(Err(_)) => { - self.is_write_fut_valid = false; - debug!("ChannelStream AsyncWrite EOF"); - Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "EOF"))) - } - } - } - - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { - if let Err(err) = self.outgoing.send("".into()) { - let err = format!("{err:?}"); - return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err))) - } - Poll::Ready(Ok(())) - } -} - -impl AsyncRead for ChannelStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - if let Some((v, s)) = self.readbuf.take_data() { - return self.readbuf.put_data(buf, v, s); - } - - let x = self.incoming.poll_recv(cx); - match x { - Poll::Ready(Some(msg)) => self.readbuf.put_data(buf, msg, 0), - Poll::Ready(None) => Poll::Ready(Ok(())), - Poll::Pending => Poll::Pending, - } - } -} diff --git a/russh/src/channel_stream/read_buffer.rs b/russh/src/channel_stream/read_buffer.rs deleted file mode 100644 index 2521e068..00000000 --- a/russh/src/channel_stream/read_buffer.rs +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// Originally from microsoft/dev-tunnels - -use std::task::Poll; - -/// Helper used when converting Future interfaces to poll-based interfaces. -/// Stores excess data that can be reused on future polls. -#[derive(Default)] -pub(crate) struct ReadBuffer(Option<(Vec, usize)>); - -impl ReadBuffer { - /// Removes any data stored in the read buffer - pub fn take_data(&mut self) -> Option<(Vec, usize)> { - self.0.take() - } - - /// Writes as many bytes as possible to the readbuf, stashing any extra. - pub fn put_data( - &mut self, - target: &mut tokio::io::ReadBuf<'_>, - bytes: Vec, - start: usize, - ) -> Poll> { - if target.remaining() >= bytes.len() - start { - if start < bytes.len() { - #[allow(clippy::indexing_slicing)] - target.put_slice(&bytes[start..]); - } - self.0 = None; - } else { - let end = start + target.remaining(); - if start < bytes.len() && end <= bytes.len() { - #[allow(clippy::indexing_slicing)] - target.put_slice(&bytes[start..end]); - } - self.0 = Some((bytes, end)); - } - - Poll::Ready(Ok(())) - } -} diff --git a/russh/src/channels/channel_ref.rs b/russh/src/channels/channel_ref.rs new file mode 100644 index 00000000..d924bb11 --- /dev/null +++ b/russh/src/channels/channel_ref.rs @@ -0,0 +1,35 @@ +use std::sync::Arc; + +use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::Mutex; + +use crate::ChannelMsg; + +/// A handle to the [`super::Channel`]'s to be able to transmit messages +/// to it and update it's `window_size`. +#[derive(Debug)] +pub struct ChannelRef { + pub(super) sender: UnboundedSender, + pub(super) window_size: Arc>, +} + +impl ChannelRef { + pub fn new(sender: UnboundedSender) -> Self { + Self { + sender, + window_size: Default::default(), + } + } + + pub fn window_size(&self) -> &Arc> { + &self.window_size + } +} + +impl std::ops::Deref for ChannelRef { + type Target = UnboundedSender; + + fn deref(&self) -> &Self::Target { + &self.sender + } +} diff --git a/russh/src/channels/channel_stream.rs b/russh/src/channels/channel_stream.rs new file mode 100644 index 00000000..9224ca67 --- /dev/null +++ b/russh/src/channels/channel_stream.rs @@ -0,0 +1,63 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::io::{ChannelRx, ChannelTx}; +use super::{ChannelId, ChannelMsg}; + +/// AsyncRead/AsyncWrite wrapper for SSH Channels +pub struct ChannelStream +where + S: From<(ChannelId, ChannelMsg)> + 'static, +{ + tx: ChannelTx, + rx: ChannelRx<'static, S>, +} + +impl ChannelStream +where + S: From<(ChannelId, ChannelMsg)>, +{ + pub(super) fn new(tx: ChannelTx, rx: ChannelRx<'static, S>) -> Self { + Self { tx, rx } + } +} + +impl AsyncRead for ChannelStream +where + S: From<(ChannelId, ChannelMsg)>, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.rx).poll_read(cx, buf) + } +} + +impl AsyncWrite for ChannelStream +where + S: From<(ChannelId, ChannelMsg)> + 'static + Send + Sync, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.tx).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.tx).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.tx).poll_shutdown(cx) + } +} diff --git a/russh/src/channels/io/mod.rs b/russh/src/channels/io/mod.rs new file mode 100644 index 00000000..976ddf41 --- /dev/null +++ b/russh/src/channels/io/mod.rs @@ -0,0 +1,48 @@ +use super::{Channel, ChannelId, ChannelMsg}; + +mod rx; +pub use rx::ChannelRx; + +mod tx; +pub use tx::ChannelTx; + +/// An enum with the ability to hold either an owned [`Channel`] +/// or a `&mut` ref to it. +#[derive(Debug)] +pub enum ChannelAsMut<'i, S> +where + S: From<(ChannelId, ChannelMsg)>, +{ + Owned(Channel), + RefMut(&'i mut Channel), +} + +impl<'i, S> AsMut> for ChannelAsMut<'i, S> +where + S: From<(ChannelId, ChannelMsg)>, +{ + fn as_mut(&mut self) -> &mut Channel { + match self { + Self::Owned(channel) => channel, + Self::RefMut(ref_mut) => ref_mut, + } + } +} + +impl From> for ChannelAsMut<'static, S> +where + S: From<(ChannelId, ChannelMsg)>, +{ + fn from(value: Channel) -> Self { + Self::Owned(value) + } +} + +impl<'i, S> From<&'i mut Channel> for ChannelAsMut<'i, S> +where + S: From<(ChannelId, ChannelMsg)>, +{ + fn from(value: &'i mut Channel) -> Self { + Self::RefMut(value) + } +} diff --git a/russh/src/channels/io/rx.rs b/russh/src/channels/io/rx.rs new file mode 100644 index 00000000..745ae340 --- /dev/null +++ b/russh/src/channels/io/rx.rs @@ -0,0 +1,91 @@ +use std::io; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +use tokio::io::AsyncRead; + +use super::{ChannelAsMut, ChannelMsg}; +use crate::ChannelId; + +#[derive(Debug)] +pub struct ChannelRx<'i, S> +where + S: From<(ChannelId, ChannelMsg)>, +{ + channel: ChannelAsMut<'i, S>, + buffer: Option<(ChannelMsg, usize)>, + + ext: Option, +} + +impl<'i, S> ChannelRx<'i, S> +where + S: From<(ChannelId, ChannelMsg)>, +{ + pub fn new(channel: impl Into>, ext: Option) -> Self { + Self { + channel: channel.into(), + buffer: None, + ext, + } + } +} + +impl<'i, S> AsyncRead for ChannelRx<'i, S> +where + S: From<(ChannelId, ChannelMsg)>, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let (msg, mut idx) = match self.buffer.take() { + Some(msg) => msg, + None => match ready!(self.channel.as_mut().receiver.poll_recv(cx)) { + Some(msg) => (msg, 0), + None => return Poll::Ready(Ok(())), + }, + }; + + match (&msg, self.ext) { + (ChannelMsg::Data { data }, None) => { + let readable = buf.remaining().min(data.len() - idx); + + // Clamped to maximum `buf.remaining()` and `data.len() - idx` with `.min` + #[allow(clippy::indexing_slicing)] + buf.put_slice(&data[idx..idx + readable]); + idx += readable; + + if idx != data.len() { + self.buffer = Some((msg, idx)); + } + + Poll::Ready(Ok(())) + } + (ChannelMsg::ExtendedData { data, ext }, Some(target)) if *ext == target => { + let readable = buf.remaining().min(data.len() - idx); + + // Clamped to maximum `buf.remaining()` and `data.len() - idx` with `.min` + #[allow(clippy::indexing_slicing)] + buf.put_slice(&data[idx..idx + readable]); + idx += readable; + + if idx != data.len() { + self.buffer = Some((msg, idx)); + } + + Poll::Ready(Ok(())) + } + (ChannelMsg::Eof, _) => { + self.channel.as_mut().receiver.close(); + + Poll::Ready(Ok(())) + } + _ => { + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } +} diff --git a/russh/src/channels/io/tx.rs b/russh/src/channels/io/tx.rs new file mode 100644 index 00000000..47ce2726 --- /dev/null +++ b/russh/src/channels/io/tx.rs @@ -0,0 +1,145 @@ +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{ready, Context, Poll}; + +use futures::FutureExt; +use tokio::io::AsyncWrite; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::mpsc::{self, OwnedPermit}; +use tokio::sync::{Mutex, OwnedMutexGuard}; + +use super::ChannelMsg; +use crate::{ChannelId, CryptoVec}; + +type BoxedThreadsafeFuture = Pin>>; +type OwnedPermitFuture = + BoxedThreadsafeFuture, ChannelMsg, usize), SendError<()>>>; + +pub struct ChannelTx { + sender: mpsc::Sender, + send_fut: Option>, + id: ChannelId, + + window_size_fut: Option>>, + window_size: Arc>, + max_packet_size: u32, + ext: Option, +} + +impl ChannelTx +where + S: From<(ChannelId, ChannelMsg)> + 'static + Send, +{ + pub fn new( + sender: mpsc::Sender, + id: ChannelId, + window_size: Arc>, + max_packet_size: u32, + ext: Option, + ) -> Self { + Self { + sender, + send_fut: None, + id, + window_size, + window_size_fut: None, + max_packet_size, + ext, + } + } + + fn poll_mk_msg(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<(ChannelMsg, usize)> { + let window_size = self.window_size.clone(); + let window_size_fut = self + .window_size_fut + .get_or_insert_with(|| Box::pin(window_size.lock_owned())); + let mut window_size = ready!(window_size_fut.poll_unpin(cx)); + self.window_size_fut.take(); + + let writable = (self.max_packet_size) + .min(*window_size) + .min(buf.len() as u32) as usize; + if writable == 0 { + // TODO fix this busywait + cx.waker().wake_by_ref(); + return Poll::Pending; + } + let mut data = CryptoVec::new_zeroed(writable); + #[allow(clippy::indexing_slicing)] // Clamped to maximum `buf.len()` with `.min` + data.copy_from_slice(&buf[..writable]); + data.resize(writable); + + *window_size -= writable as u32; + drop(window_size); + + let msg = match self.ext { + None => ChannelMsg::Data { data }, + Some(ext) => ChannelMsg::ExtendedData { data, ext }, + }; + + Poll::Ready((msg, writable)) + } + + fn activate(&mut self, msg: ChannelMsg, writable: usize) -> &mut OwnedPermitFuture { + use futures::TryFutureExt; + self.send_fut.insert(Box::pin( + self.sender + .clone() + .reserve_owned() + .map_ok(move |p| (p, msg, writable)), + )) + } + + fn handle_write_result( + &mut self, + r: Result<(OwnedPermit, ChannelMsg, usize), SendError<()>>, + ) -> Result { + self.send_fut = None; + match r { + Ok((permit, msg, writable)) => { + permit.send((self.id, msg).into()); + Ok(writable) + } + Err(SendError(())) => Err(io::Error::new(io::ErrorKind::BrokenPipe, "channel closed")), + } + } +} + +impl AsyncWrite for ChannelTx +where + S: From<(ChannelId, ChannelMsg)> + 'static + Send, +{ + #[allow(clippy::too_many_lines)] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let send_fut = if let Some(x) = self.send_fut.as_mut() { + x + } else { + let (msg, writable) = ready!(self.poll_mk_msg(cx, buf)); + self.activate(msg, writable) + }; + let r = ready!(send_fut.as_mut().poll_unpin(cx)); + Poll::Ready(self.handle_write_result(r)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let send_fut = if let Some(x) = self.send_fut.as_mut() { + x + } else { + self.activate(ChannelMsg::Eof, 0) + }; + let r = ready!(send_fut.as_mut().poll_unpin(cx)).map(|(p, _, _)| (p, ChannelMsg::Eof, 0)); + Poll::Ready(self.handle_write_result(r).map(drop)) + } +} diff --git a/russh/src/channels.rs b/russh/src/channels/mod.rs similarity index 54% rename from russh/src/channels.rs rename to russh/src/channels/mod.rs index e5da311e..76e2f371 100644 --- a/russh/src/channels.rs +++ b/russh/src/channels/mod.rs @@ -1,8 +1,19 @@ -use russh_cryptovec::CryptoVec; +use std::{pin::Pin, sync::Arc}; + +use futures::{Future, FutureExt as _}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::mpsc::{Sender, UnboundedReceiver}; -use log::debug; +use tokio::sync::Mutex; + +use crate::{ChannelId, ChannelOpenFailure, CryptoVec, Error, Pty, Sig}; + +pub mod io; + +mod channel_ref; +pub use channel_ref::ChannelRef; -use crate::{ChannelId, ChannelOpenFailure, ChannelStream, Error, Pty, Sig}; +mod channel_stream; +pub use channel_stream::ChannelStream; #[derive(Debug)] #[non_exhaustive] @@ -21,6 +32,7 @@ pub enum ChannelMsg { ext: u32, }, Eof, + Close, /// (client only) RequestPty { want_reply: bool, @@ -98,8 +110,6 @@ pub enum ChannelMsg { Success, /// (server only) Failure, - /// (server only) - Close, OpenFailure(ChannelOpenFailure), } @@ -111,7 +121,7 @@ pub struct Channel> { pub(crate) sender: Sender, pub(crate) receiver: UnboundedReceiver, pub(crate) max_packet_size: u32, - pub(crate) window_size: u32, + pub(crate) window_size: Arc>, } impl> std::fmt::Debug for Channel { @@ -120,21 +130,45 @@ impl> std::fmt::Debug for Channel { } } -impl + Send + 'static> Channel { - pub fn id(&self) -> ChannelId { - self.id +impl + Send + Sync + 'static> Channel { + pub(crate) fn new( + id: ChannelId, + sender: Sender, + max_packet_size: u32, + window_size: u32, + ) -> (Self, ChannelRef) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let window_size = Arc::new(Mutex::new(window_size)); + + ( + Self { + id, + sender, + receiver: rx, + max_packet_size, + window_size: window_size.clone(), + }, + ChannelRef { + sender: tx, + window_size, + }, + ) } /// Returns the min between the maximum packet size and the /// remaining window size in the channel. - pub fn writable_packet_size(&self) -> usize { - self.max_packet_size.min(self.window_size) as usize + pub async fn writable_packet_size(&self) -> usize { + self.max_packet_size.min(*self.window_size.lock().await) as usize + } + + pub fn id(&self) -> ChannelId { + self.id } /// Request a pseudo-terminal with the given characteristics. #[allow(clippy::too_many_arguments)] // length checked pub async fn request_pty( - &mut self, + &self, want_reply: bool, term: &str, col_width: u32, @@ -152,42 +186,33 @@ impl + Send + 'static> Channel { pix_height, terminal_modes: terminal_modes.to_vec(), }) - .await?; - Ok(()) + .await } /// Request a remote shell. - pub async fn request_shell(&mut self, want_reply: bool) -> Result<(), Error> { - self.send_msg(ChannelMsg::RequestShell { want_reply }) - .await?; - Ok(()) + pub async fn request_shell(&self, want_reply: bool) -> Result<(), Error> { + self.send_msg(ChannelMsg::RequestShell { want_reply }).await } /// Execute a remote program (will be passed to a shell). This can /// be used to implement scp (by calling a remote scp and /// tunneling to its standard input). - pub async fn exec>>( - &mut self, - want_reply: bool, - command: A, - ) -> Result<(), Error> { + pub async fn exec>>(&self, want_reply: bool, command: A) -> Result<(), Error> { self.send_msg(ChannelMsg::Exec { want_reply, command: command.into(), }) - .await?; - Ok(()) + .await } /// Signal a remote process. - pub async fn signal(&mut self, signal: Sig) -> Result<(), Error> { - self.send_msg(ChannelMsg::Signal { signal }).await?; - Ok(()) + pub async fn signal(&self, signal: Sig) -> Result<(), Error> { + self.send_msg(ChannelMsg::Signal { signal }).await } /// Request the start of a subsystem with the given name. pub async fn request_subsystem>( - &mut self, + &self, want_reply: bool, name: A, ) -> Result<(), Error> { @@ -195,8 +220,7 @@ impl + Send + 'static> Channel { want_reply, name: name.into(), }) - .await?; - Ok(()) + .await } /// Request X11 forwarding through an already opened X11 @@ -204,7 +228,7 @@ impl + Send + 'static> Channel { /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-6.3.1) /// for security issues related to cookies. pub async fn request_x11, B: Into>( - &mut self, + &self, want_reply: bool, single_connection: bool, x11_authentication_protocol: A, @@ -218,13 +242,12 @@ impl + Send + 'static> Channel { x11_authentication_cookie: x11_authentication_cookie.into(), x11_screen_number, }) - .await?; - Ok(()) + .await } /// Set a remote environment variable. pub async fn set_env, B: Into>( - &mut self, + &self, want_reply: bool, variable_name: A, variable_value: B, @@ -234,13 +257,12 @@ impl + Send + 'static> Channel { variable_name: variable_name.into(), variable_value: variable_value.into(), }) - .await?; - Ok(()) + .await } /// Inform the server that our window size has changed. pub async fn window_change( - &mut self, + &self, col_width: u32, row_height: u32, pix_width: u32, @@ -252,107 +274,68 @@ impl + Send + 'static> Channel { pix_width, pix_height, }) - .await?; - Ok(()) + .await } /// Inform the server that we will accept agent forwarding channels - pub async fn agent_forward(&mut self, want_reply: bool) -> Result<(), Error> { - self.send_msg(ChannelMsg::AgentForward { want_reply }) - .await?; - Ok(()) + pub async fn agent_forward(&self, want_reply: bool) -> Result<(), Error> { + self.send_msg(ChannelMsg::AgentForward { want_reply }).await } /// Send data to a channel. - pub async fn data(&mut self, data: R) -> Result<(), Error> { + pub async fn data(&self, data: R) -> Result<(), Error> { self.send_data(None, data).await } /// Send data to a channel. The number of bytes added to the /// "sending pipeline" (to be processed by the event loop) is /// returned. - pub async fn extended_data( - &mut self, + pub async fn extended_data( + &self, ext: u32, data: R, ) -> Result<(), Error> { self.send_data(Some(ext), data).await } - async fn send_data( - &mut self, + async fn send_data( + &self, ext: Option, mut data: R, ) -> Result<(), Error> { - let mut total = 0; - loop { - // wait for the window to be restored. - while self.window_size == 0 { - match self.receiver.recv().await { - Some(ChannelMsg::WindowAdjusted { new_size }) => { - debug!("window adjusted: {:?}", new_size); - self.window_size = new_size; - break; - } - Some(msg) => { - debug!("unexpected channel msg: {:?}", msg); - } - None => break, - } - } - debug!( - "sending data, self.window_size = {:?}, self.max_packet_size = {:?}, total = {:?}", - self.window_size, self.max_packet_size, total - ); - let sendable = self.window_size.min(self.max_packet_size) as usize; - - debug!("sendable {:?}", sendable); - - // If we can not send anymore, continue - // and wait for server window adjustment - if sendable == 0 { - continue; - } + let mut tx = self.make_writer_ext(ext); + + tokio::io::copy(&mut data, &mut tx).await?; - let mut c = CryptoVec::new_zeroed(sendable); - let n = data.read(&mut c[..]).await?; - total += n; - c.resize(n); - self.window_size -= n as u32; - self.send_data_packet(ext, c).await?; - if n == 0 { - break; - } else if self.window_size > 0 { - continue; - } - } Ok(()) } - async fn send_data_packet(&mut self, ext: Option, data: CryptoVec) -> Result<(), Error> { - self.send_msg(if let Some(ext) = ext { - ChannelMsg::ExtendedData { ext, data } - } else { - ChannelMsg::Data { data } - }) - .await?; - Ok(()) + pub async fn eof(&self) -> Result<(), Error> { + self.send_msg(ChannelMsg::Eof).await } - pub async fn eof(&mut self) -> Result<(), Error> { - self.send_msg(ChannelMsg::Eof).await?; - Ok(()) + /// Request that the channel be closed. + pub async fn close(&self) -> Result<(), Error> { + self.send_msg(ChannelMsg::Close).await } + /// Get a `FnOnce` that can be used to send a signal through this channel + pub fn get_signal_sender( + &self, + ) -> impl FnOnce(Sig) -> Pin> + std::marker::Send>> + { + let sender = self.sender.clone(); + let id = self.id; - /// Wait for data to come. - pub async fn wait(&mut self) -> Option { - match self.receiver.recv().await { - Some(ChannelMsg::WindowAdjusted { new_size }) => { - self.window_size = new_size; - Some(ChannelMsg::WindowAdjusted { new_size }) + move |signal| { + async move { + sender + .send((id, ChannelMsg::Signal { signal }).into()) + .await + .map_err(|_| Error::SendError)?; + + Ok(()) } - Some(msg) => Some(msg), - None => None, + .boxed() } } @@ -363,51 +346,53 @@ impl + Send + 'static> Channel { .map_err(|_| Error::SendError) } - /// Request that the channel be closed. - pub async fn close(&self) -> Result<(), Error> { - self.send_msg(ChannelMsg::Close).await?; - Ok(()) + /// Awaits an incoming [`ChannelMsg`], this method returns [`None`] if the channel has been closed. + pub async fn wait(&mut self) -> Option { + self.receiver.recv().await } - pub fn into_stream(mut self) -> ChannelStream { - let (stream, mut r_rx, w_tx) = ChannelStream::new(); - - tokio::spawn(async move { - loop { - tokio::select! { - data = r_rx.recv() => { - match data { - Some(data) if !data.is_empty() => self.data(&data[..]).await?, - Some(_) => { - log::debug!("closing chan {:?}, received empty data", &self.id); - self.eof().await?; - self.close().await?; - break; - }, - None => { - self.close().await?; - break - } - } - }, - msg = self.wait() => { - match msg { - Some(ChannelMsg::Data { data }) => { - w_tx.send(data[..].into()).map_err(|_| crate::Error::SendError)?; - } - Some(ChannelMsg::Eof) => { - // Send a 0-length chunk to indicate EOF. - w_tx.send("".into()).map_err(|_| crate::Error::SendError)?; - break - } - None => break, - _ => (), - } - } - } - } - Ok::<_, crate::Error>(()) - }); - stream + /// Consume the [`Channel`] to produce a bidirectionnal stream, + /// sending and receiving [`ChannelMsg::Data`] as `AsyncRead` + `AsyncWrite`. + pub fn into_stream(self) -> ChannelStream { + ChannelStream::new( + io::ChannelTx::new( + self.sender.clone(), + self.id, + self.window_size.clone(), + self.max_packet_size, + None, + ), + io::ChannelRx::new(self, None), + ) + } + + /// Make a reader for the [`Channel`] to receive [`ChannelMsg::Data`] + /// through the `AsyncRead` trait. + pub fn make_reader(&mut self) -> impl AsyncRead + '_ { + self.make_reader_ext(None) + } + + /// Make a reader for the [`Channel`] to receive [`ChannelMsg::Data`] or [`ChannelMsg::ExtendedData`] + /// depending on the `ext` parameter, through the `AsyncRead` trait. + pub fn make_reader_ext(&mut self, ext: Option) -> impl AsyncRead + '_ { + io::ChannelRx::new(self, ext) + } + + /// Make a writer for the [`Channel`] to send [`ChannelMsg::Data`] + /// through the `AsyncWrite` trait. + pub fn make_writer(&self) -> impl AsyncWrite { + self.make_writer_ext(None) + } + + /// Make a writer for the [`Channel`] to send [`ChannelMsg::Data`] or [`ChannelMsg::ExtendedData`] + /// depending on the `ext` parameter, through the `AsyncWrite` trait. + pub fn make_writer_ext(&self, ext: Option) -> impl AsyncWrite { + io::ChannelTx::new( + self.sender.clone(), + self.id, + self.window_size.clone(), + self.max_packet_size, + ext, + ) } } diff --git a/russh/src/cipher/block.rs b/russh/src/cipher/block.rs index ccd6a4de..aa4ecdf6 100644 --- a/russh/src/cipher/block.rs +++ b/russh/src/cipher/block.rs @@ -11,6 +11,7 @@ // limitations under the License. // +use std::convert::TryInto; use std::marker::PhantomData; use aes::cipher::{IvSizeUser, KeyIvInit, KeySizeUser, StreamCipher}; @@ -21,9 +22,9 @@ use super::super::Error; use super::PACKET_LENGTH_LEN; use crate::mac::{Mac, MacAlgorithm}; -pub struct SshBlockCipher(pub PhantomData); +pub struct SshBlockCipher(pub PhantomData); -impl super::Cipher +impl super::Cipher for SshBlockCipher { fn key_len(&self) -> usize { @@ -73,29 +74,44 @@ impl su } } -pub struct OpeningKey { - cipher: C, - mac: Box, +pub struct OpeningKey { + pub(crate) cipher: C, + pub(crate) mac: Box, } -pub struct SealingKey { - cipher: C, - mac: Box, +pub struct SealingKey { + pub(crate) cipher: C, + pub(crate) mac: Box, } -impl super::OpeningKey for OpeningKey { +impl super::OpeningKey for OpeningKey { + fn packet_length_to_read_for_block_length(&self) -> usize { + 16 + } + fn decrypt_packet_length( &self, _sequence_number: u32, - mut encrypted_packet_length: [u8; 4], + encrypted_packet_length: &[u8], ) -> [u8; 4] { + let mut first_block = [0u8; 16]; + // Fine because of self.packet_length_to_read_for_block_length() + #[allow(clippy::indexing_slicing)] + first_block.copy_from_slice(&encrypted_packet_length[..16]); + if self.mac.is_etm() { - encrypted_packet_length + // Fine because of self.packet_length_to_read_for_block_length() + #[allow(clippy::unwrap_used, clippy::indexing_slicing)] + encrypted_packet_length[..4].try_into().unwrap() } else { // Work around uncloneable Aes<> let mut cipher: C = unsafe { std::ptr::read(&self.cipher as *const C) }; - cipher.apply_keystream(&mut encrypted_packet_length); - encrypted_packet_length + + cipher.decrypt_data(&mut first_block); + + // Fine because of self.packet_length_to_read_for_block_length() + #[allow(clippy::unwrap_used, clippy::indexing_slicing)] + first_block[..4].try_into().unwrap() } } @@ -118,9 +134,9 @@ impl super::OpeningKey for OpeningKe } #[allow(clippy::indexing_slicing)] self.cipher - .apply_keystream(&mut ciphertext_in_plaintext_out[PACKET_LENGTH_LEN..]); + .decrypt_data(&mut ciphertext_in_plaintext_out[PACKET_LENGTH_LEN..]); } else { - self.cipher.apply_keystream(ciphertext_in_plaintext_out); + self.cipher.decrypt_data(ciphertext_in_plaintext_out); if !self .mac @@ -129,11 +145,13 @@ impl super::OpeningKey for OpeningKe return Err(Error::PacketAuth); } } - Ok(ciphertext_in_plaintext_out) + + #[allow(clippy::indexing_slicing)] + Ok(&ciphertext_in_plaintext_out[PACKET_LENGTH_LEN..]) } } -impl super::SealingKey for SealingKey { +impl super::SealingKey for SealingKey { fn padding_length(&self, payload: &[u8]) -> usize { let block_size = 16; @@ -174,13 +192,28 @@ impl super::SealingKey for SealingKe if self.mac.is_etm() { #[allow(clippy::indexing_slicing)] self.cipher - .apply_keystream(&mut plaintext_in_ciphertext_out[PACKET_LENGTH_LEN..]); + .encrypt_data(&mut plaintext_in_ciphertext_out[PACKET_LENGTH_LEN..]); self.mac .compute(sequence_number, plaintext_in_ciphertext_out, tag_out); } else { self.mac .compute(sequence_number, plaintext_in_ciphertext_out, tag_out); - self.cipher.apply_keystream(plaintext_in_ciphertext_out); + self.cipher.encrypt_data(plaintext_in_ciphertext_out); } } } + +pub trait BlockStreamCipher { + fn encrypt_data(&mut self, data: &mut [u8]); + fn decrypt_data(&mut self, data: &mut [u8]); +} + +impl BlockStreamCipher for T { + fn encrypt_data(&mut self, data: &mut [u8]) { + self.apply_keystream(data); + } + + fn decrypt_data(&mut self, data: &mut [u8]) { + self.apply_keystream(data); + } +} diff --git a/russh/src/cipher/cbc.rs b/russh/src/cipher/cbc.rs new file mode 100644 index 00000000..87a0c66a --- /dev/null +++ b/russh/src/cipher/cbc.rs @@ -0,0 +1,53 @@ +use aes::cipher::{ + BlockCipher, BlockDecrypt, BlockDecryptMut, BlockEncrypt, BlockEncryptMut, InnerIvInit, Iv, + IvSizeUser, +}; +use cbc::{Decryptor, Encryptor}; +use digest::crypto_common::InnerUser; +use generic_array::GenericArray; + +use super::block::BlockStreamCipher; + +pub struct CbcWrapper { + encryptor: Encryptor, + decryptor: Decryptor, +} + +impl InnerUser for CbcWrapper { + type Inner = C; +} + +impl IvSizeUser for CbcWrapper { + type IvSize = C::BlockSize; +} + +impl BlockStreamCipher for CbcWrapper { + fn encrypt_data(&mut self, data: &mut [u8]) { + for chunk in data.chunks_exact_mut(C::block_size()) { + let mut block: GenericArray = GenericArray::clone_from_slice(chunk); + self.encryptor.encrypt_block_mut(&mut block); + chunk.clone_from_slice(&block); + } + } + + fn decrypt_data(&mut self, data: &mut [u8]) { + for chunk in data.chunks_exact_mut(C::block_size()) { + let mut block = GenericArray::clone_from_slice(chunk); + self.decryptor.decrypt_block_mut(&mut block); + chunk.clone_from_slice(&block); + } + } +} + +impl InnerIvInit for CbcWrapper +where + C: BlockEncryptMut + BlockCipher, +{ + #[inline] + fn inner_iv_init(cipher: C, iv: &Iv) -> Self { + Self { + encryptor: Encryptor::inner_iv_init(cipher.clone(), iv), + decryptor: Decryptor::inner_iv_init(cipher, iv), + } + } +} diff --git a/russh/src/cipher/chacha20poly1305.rs b/russh/src/cipher/chacha20poly1305.rs index cab3eece..605301bc 100644 --- a/russh/src/cipher/chacha20poly1305.rs +++ b/russh/src/cipher/chacha20poly1305.rs @@ -15,6 +15,8 @@ // http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.chacha20poly1305?annotate=HEAD +use std::convert::TryInto; + use aes::cipher::{BlockSizeUser, StreamCipherSeek}; use byteorder::{BigEndian, ByteOrder}; use chacha20::cipher::{KeyInit, KeyIvInit, StreamCipher}; @@ -94,11 +96,16 @@ impl super::OpeningKey for OpeningKey { fn decrypt_packet_length( &self, sequence_number: u32, - mut encrypted_packet_length: [u8; 4], + encrypted_packet_length: &[u8], ) -> [u8; 4] { + // Fine because of self.packet_length_to_read_for_block_length() + #[allow(clippy::unwrap_used, clippy::indexing_slicing)] + let mut encrypted_packet_length: [u8; 4] = encrypted_packet_length.try_into().unwrap(); + let nonce = make_counter(sequence_number); let mut cipher = ChaCha20Legacy::new(&self.k1, &nonce); cipher.apply_keystream(&mut encrypted_packet_length); + encrypted_packet_length } diff --git a/russh/src/cipher/clear.rs b/russh/src/cipher/clear.rs index ddd552db..68b3df59 100644 --- a/russh/src/cipher/clear.rs +++ b/russh/src/cipher/clear.rs @@ -13,6 +13,8 @@ // limitations under the License. // +use std::convert::TryInto; + use crate::mac::MacAlgorithm; use crate::Error; @@ -48,8 +50,10 @@ impl super::Cipher for Clear { } impl super::OpeningKey for Key { - fn decrypt_packet_length(&self, _seqn: u32, packet_length: [u8; 4]) -> [u8; 4] { - packet_length + fn decrypt_packet_length(&self, _seqn: u32, packet_length: &[u8]) -> [u8; 4] { + // Fine because of self.packet_length_to_read_for_block_length() + #[allow(clippy::unwrap_used, clippy::indexing_slicing)] + packet_length.try_into().unwrap() } fn tag_len(&self) -> usize { diff --git a/russh/src/cipher/gcm.rs b/russh/src/cipher/gcm.rs index f737716c..bf19a59b 100644 --- a/russh/src/cipher/gcm.rs +++ b/russh/src/cipher/gcm.rs @@ -15,8 +15,9 @@ // http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.chacha20poly1305?annotate=HEAD +use std::convert::TryInto; + use aes_gcm::{AeadCore, AeadInPlace, Aes256Gcm, KeyInit, KeySizeUser}; -use byteorder::{BigEndian, ByteOrder}; use digest::typenum::Unsigned; use generic_array::GenericArray; use rand::RngCore; @@ -84,36 +85,25 @@ pub struct SealingKey { cipher: Aes256Gcm, } -const GCM_COUNTER_OFFSET: u64 = 3; - -fn make_nonce( - nonce: &GenericArray, - sequence_number: u32, -) -> GenericArray { - let mut new_nonce = GenericArray::::default(); - new_nonce.clone_from_slice(nonce); - // Increment the nonce - let i0 = new_nonce.len() - 8; - +fn inc_nonce(nonce: &mut GenericArray) { + let mut carry = 1; #[allow(clippy::indexing_slicing)] // length checked - let ctr = BigEndian::read_u64(&new_nonce[i0..]); - - // GCM requires the counter to start from 1 - #[allow(clippy::indexing_slicing)] // length checked - BigEndian::write_u64( - &mut new_nonce[i0..], - ctr + sequence_number as u64 - GCM_COUNTER_OFFSET, - ); - new_nonce + for i in (0..nonce.len()).rev() { + let n = nonce[i] as u16 + carry; + nonce[i] = n as u8; + carry = n >> 8; + } } impl super::OpeningKey for OpeningKey { fn decrypt_packet_length( &self, _sequence_number: u32, - encrypted_packet_length: [u8; 4], + encrypted_packet_length: &[u8], ) -> [u8; 4] { - encrypted_packet_length + // Fine because of self.packet_length_to_read_for_block_length() + #[allow(clippy::unwrap_used, clippy::indexing_slicing)] + encrypted_packet_length.try_into().unwrap() } fn tag_len(&self) -> usize { @@ -122,7 +112,7 @@ impl super::OpeningKey for OpeningKey { fn open<'a>( &mut self, - sequence_number: u32, + _sequence_number: u32, ciphertext_in_plaintext_out: &'a mut [u8], tag: &[u8], ) -> Result<&'a [u8], Error> { @@ -137,22 +127,23 @@ impl super::OpeningKey for OpeningKey { #[allow(clippy::indexing_slicing)] // length checked buffer.copy_from_slice(&ciphertext_in_plaintext_out[super::PACKET_LENGTH_LEN..]); - let nonce = make_nonce(&self.nonce, sequence_number); - let mut tag_buf = GenericArray::::default(); tag_buf.clone_from_slice(tag); #[allow(clippy::indexing_slicing)] self.cipher .decrypt_in_place_detached( - &nonce, + &self.nonce, &packet_length, &mut ciphertext_in_plaintext_out[super::PACKET_LENGTH_LEN..], &tag_buf, ) .map_err(|_| Error::DecryptionError)?; - Ok(ciphertext_in_plaintext_out) + inc_nonce(&mut self.nonce); + + #[allow(clippy::indexing_slicing)] + Ok(&ciphertext_in_plaintext_out[super::PACKET_LENGTH_LEN..]) } } @@ -182,7 +173,7 @@ impl super::SealingKey for SealingKey { fn seal( &mut self, - sequence_number: u32, + _sequence_number: u32, plaintext_in_ciphertext_out: &mut [u8], tag: &mut [u8], ) { @@ -191,18 +182,17 @@ impl super::SealingKey for SealingKey { #[allow(clippy::indexing_slicing)] // length checked packet_length.clone_from_slice(&plaintext_in_ciphertext_out[..super::PACKET_LENGTH_LEN]); - let nonce = make_nonce(&self.nonce, sequence_number); - #[allow(clippy::indexing_slicing, clippy::unwrap_used)] let tag_out = self .cipher .encrypt_in_place_detached( - &nonce, + &self.nonce, &packet_length, &mut plaintext_in_ciphertext_out[super::PACKET_LENGTH_LEN..], ) .unwrap(); + inc_nonce(&mut self.nonce); tag.clone_from_slice(&tag_out) } } diff --git a/russh/src/cipher/mod.rs b/russh/src/cipher/mod.rs index 1251253d..a474c2da 100644 --- a/russh/src/cipher/mod.rs +++ b/russh/src/cipher/mod.rs @@ -14,26 +14,32 @@ //! //! This module exports cipher names for use with [Preferred]. +use std::borrow::Borrow; use std::collections::HashMap; +use std::convert::TryFrom; use std::fmt::Debug; use std::marker::PhantomData; use std::num::Wrapping; use aes::{Aes128, Aes192, Aes256}; use byteorder::{BigEndian, ByteOrder}; +use cbc::CbcWrapper; use ctr::Ctr128BE; +use des::TdesEde3; +use log::debug; use once_cell::sync::Lazy; use tokio::io::{AsyncRead, AsyncReadExt}; -use log::debug; use crate::mac::MacAlgorithm; use crate::sshbuffer::SSHBuffer; use crate::Error; pub(crate) mod block; +pub(crate) mod cbc; pub(crate) mod chacha20poly1305; pub(crate) mod clear; pub(crate) mod gcm; + use block::SshBlockCipher; use chacha20poly1305::SshChacha20Poly1305Cipher; use clear::Clear; @@ -65,10 +71,18 @@ pub(crate) trait Cipher { /// `clear` pub const CLEAR: Name = Name("clear"); +/// `3des-cbc` +pub const TRIPLE_DES_CBC: Name = Name("3des-cbc"); /// `aes128-ctr` pub const AES_128_CTR: Name = Name("aes128-ctr"); /// `aes192-ctr` pub const AES_192_CTR: Name = Name("aes192-ctr"); +/// `aes128-cbc` +pub const AES_128_CBC: Name = Name("aes128-cbc"); +/// `aes192-cbc` +pub const AES_192_CBC: Name = Name("aes192-cbc"); +/// `aes256-cbc` +pub const AES_256_CBC: Name = Name("aes256-cbc"); /// `aes256-ctr` pub const AES_256_CTR: Name = Name("aes256-ctr"); /// `aes256-gcm@openssh.com` @@ -79,22 +93,45 @@ pub const CHACHA20_POLY1305: Name = Name("chacha20-poly1305@openssh.com"); pub const NONE: Name = Name("none"); static _CLEAR: Clear = Clear {}; +static _3DES_CBC: SshBlockCipher> = SshBlockCipher(PhantomData); static _AES_128_CTR: SshBlockCipher> = SshBlockCipher(PhantomData); static _AES_192_CTR: SshBlockCipher> = SshBlockCipher(PhantomData); static _AES_256_CTR: SshBlockCipher> = SshBlockCipher(PhantomData); static _AES_256_GCM: GcmCipher = GcmCipher {}; +static _AES_128_CBC: SshBlockCipher> = SshBlockCipher(PhantomData); +static _AES_192_CBC: SshBlockCipher> = SshBlockCipher(PhantomData); +static _AES_256_CBC: SshBlockCipher> = SshBlockCipher(PhantomData); static _CHACHA20_POLY1305: SshChacha20Poly1305Cipher = SshChacha20Poly1305Cipher {}; +pub static ALL_CIPHERS: &[&Name] = &[ + &CLEAR, + &NONE, + &TRIPLE_DES_CBC, + &AES_128_CTR, + &AES_192_CTR, + &AES_256_CTR, + &AES_256_GCM, + &AES_128_CBC, + &AES_192_CBC, + &AES_256_CBC, + &CHACHA20_POLY1305, +]; + pub(crate) static CIPHERS: Lazy> = Lazy::new(|| { let mut h: HashMap<&'static Name, &(dyn Cipher + Send + Sync)> = HashMap::new(); h.insert(&CLEAR, &_CLEAR); h.insert(&NONE, &_CLEAR); + h.insert(&TRIPLE_DES_CBC, &_3DES_CBC); h.insert(&AES_128_CTR, &_AES_128_CTR); h.insert(&AES_192_CTR, &_AES_192_CTR); h.insert(&AES_256_CTR, &_AES_256_CTR); h.insert(&AES_256_GCM, &_AES_256_GCM); + h.insert(&AES_128_CBC, &_AES_128_CBC); + h.insert(&AES_192_CBC, &_AES_192_CBC); + h.insert(&AES_256_CBC, &_AES_256_CBC); h.insert(&CHACHA20_POLY1305, &_CHACHA20_POLY1305); + assert_eq!(h.len(), ALL_CIPHERS.len()); h }); @@ -106,6 +143,19 @@ impl AsRef for Name { } } +impl Borrow for &Name { + fn borrow(&self) -> &str { + self.0 + } +} + +impl TryFrom<&str> for Name { + type Error = (); + fn try_from(s: &str) -> Result { + CIPHERS.keys().find(|x| x.0 == s).map(|x| **x).ok_or(()) + } +} + pub(crate) struct CipherPair { pub local_to_remote: Box, pub remote_to_local: Box, @@ -118,7 +168,11 @@ impl Debug for CipherPair { } pub(crate) trait OpeningKey { - fn decrypt_packet_length(&self, seqn: u32, encrypted_packet_length: [u8; 4]) -> [u8; 4]; + fn packet_length_to_read_for_block_length(&self) -> usize { + 4 + } + + fn decrypt_packet_length(&self, seqn: u32, encrypted_packet_length: &[u8]) -> [u8; 4]; fn tag_len(&self) -> usize; @@ -182,7 +236,8 @@ pub(crate) async fn read<'a, R: AsyncRead + Unpin>( cipher: &'a mut (dyn OpeningKey + Send), ) -> Result { if buffer.len == 0 { - let mut len = [0; 4]; + let mut len = vec![0; cipher.packet_length_to_read_for_block_length()]; + stream.read_exact(&mut len).await?; debug!("reading, len = {:?}", len); { @@ -190,8 +245,14 @@ pub(crate) async fn read<'a, R: AsyncRead + Unpin>( buffer.buffer.clear(); buffer.buffer.extend(&len); debug!("reading, seqn = {:?}", seqn); - let len = cipher.decrypt_packet_length(seqn, len); - buffer.len = BigEndian::read_u32(&len) as usize + cipher.tag_len(); + let len = cipher.decrypt_packet_length(seqn, &len); + let len = BigEndian::read_u32(&len) as usize; + + if len > MAXIMUM_PACKET_LEN { + return Err(Error::PacketSize(len)); + } + + buffer.len = len + cipher.tag_len(); debug!("reading, clear len = {:?}", buffer.len); } } @@ -199,7 +260,9 @@ pub(crate) async fn read<'a, R: AsyncRead + Unpin>( buffer.buffer.resize(buffer.len + 4); debug!("read_exact {:?}", buffer.len + 4); #[allow(clippy::indexing_slicing)] // length checked - stream.read_exact(&mut buffer.buffer[4..]).await?; + stream + .read_exact(&mut buffer.buffer[cipher.packet_length_to_read_for_block_length()..]) + .await?; debug!("read_exact done"); let seqn = buffer.seqn.0; let ciphertext_len = buffer.buffer.len() - cipher.tag_len(); @@ -227,5 +290,6 @@ pub(crate) async fn read<'a, R: AsyncRead + Unpin>( pub(crate) const PACKET_LENGTH_LEN: usize = 4; const MINIMUM_PACKET_LEN: usize = 16; +const MAXIMUM_PACKET_LEN: usize = 256 * 1024; const PADDING_LENGTH_LEN: usize = 1; diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index cc7df2b4..1a3ac7ba 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -14,20 +14,20 @@ // use std::cell::RefCell; use std::convert::TryInto; +use std::num::Wrapping; use log::{debug, error, info, trace, warn}; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::{Encoding, Reader}; -use russh_keys::key::parse_public_key; -use tokio::sync::mpsc::unbounded_channel; use crate::client::{Handler, Msg, Prompt, Reply, Session}; use crate::key::PubKey; +use crate::keys::encoding::{Encoding, Reader}; +use crate::keys::key::parse_public_key; use crate::negotiation::{Named, Select}; use crate::parsing::{ChannelOpenConfirmation, ChannelType, OpenChannelMessage}; -use crate::session::{Encrypted, EncryptedState, Kex, KexInit}; +use crate::session::{Encrypted, EncryptedState, GlobalRequestResponse, Kex, KexInit}; use crate::{ - auth, msg, negotiation, Channel, ChannelId, ChannelMsg, ChannelOpenFailure, ChannelParams, Sig, + auth, msg, negotiation, strict_kex_violation, Channel, ChannelId, ChannelMsg, + ChannelOpenFailure, ChannelParams, CryptoVec, Sig, }; thread_local! { @@ -36,10 +36,11 @@ thread_local! { impl Session { pub(crate) async fn client_read_encrypted( - mut self, - mut client: H, + &mut self, + client: &mut H, + seqn: &mut Wrapping, buf: &[u8], - ) -> Result<(H, Self), H::Error> { + ) -> Result<(), H::Error> { #[allow(clippy::indexing_slicing)] // length checked { trace!( @@ -55,10 +56,14 @@ impl Session { // If we're not currently re-keying, but buf is a rekey request let kexinit = if let Some(Kex::Init(kexinit)) = enc.rekey.take() { Some(kexinit) - } else if let Some(exchange) = std::mem::replace(&mut enc.exchange, None) { + } else if let Some(exchange) = enc.exchange.take() { Some(KexInit::received_rekey( exchange, - negotiation::Client::read_kex(buf, &self.common.config.as_ref().preferred)?, + negotiation::Client::read_kex( + buf, + &self.common.config.as_ref().preferred, + None, + )?, &enc.session_id, )) } else { @@ -66,6 +71,12 @@ impl Session { }; if let Some(kexinit) = kexinit { + if let Some(ref algo) = kexinit.algo { + if self.common.strict_kex && !algo.strict_kex { + return Err(strict_kex_violation(msg::KEXINIT, 0).into()); + } + } + let dhdone = kexinit.client_parse( self.common.config.as_ref(), &mut *self.common.cipher.local_to_remote, @@ -81,7 +92,7 @@ impl Session { unreachable!() } self.flush()?; - return Ok((client, self)); + return Ok(()); } if let Some(ref mut enc) = self.common.encrypted { @@ -90,18 +101,18 @@ impl Session { return if kexdhdone.names.ignore_guessed { kexdhdone.names.ignore_guessed = false; enc.rekey = Some(Kex::DhDone(kexdhdone)); - Ok((client, self)) + Ok(()) } else if buf.first() == Some(&msg::KEX_ECDH_REPLY) { // We've sent ECDH_INIT, waiting for ECDH_REPLY - let (kex, h) = kexdhdone.server_key_check(true, client, buf).await?; - client = h; + let kex = kexdhdone.server_key_check(true, client, buf).await?; enc.rekey = Some(Kex::Keys(kex)); self.common .cipher .local_to_remote .write(&[msg::NEWKEYS], &mut self.common.write_buffer); self.flush()?; - Ok((client, self)) + self.common.maybe_reset_seqn(); + Ok(()) } else { error!("Wrong packet received"); Err(crate::Error::Inconsistent.into()) @@ -118,15 +129,18 @@ impl Session { enc.flush_all_pending(); let mut pending = std::mem::take(&mut self.pending_reads); for p in pending.drain(..) { - let (h, s) = self.process_packet(client, &p).await?; - self = s; - client = h; + self.process_packet(client, &p).await?; } self.pending_reads = pending; self.pending_len = 0; self.common.newkeys(newkeys); self.flush()?; - return Ok((client, self)); + + if self.common.strict_kex { + *seqn = Wrapping(0); + } + + return Ok(()); } Some(Kex::Init(k)) => { enc.rekey = Some(Kex::Init(k)); @@ -135,7 +149,7 @@ impl Session { return Err(crate::Error::Pending.into()); } self.pending_reads.push(CryptoVec::from_slice(buf)); - return Ok((client, self)); + return Ok(()); } rek => enc.rekey = rek, } @@ -144,10 +158,10 @@ impl Session { } async fn process_packet( - mut self, - client: H, + &mut self, + client: &mut H, buf: &[u8], - ) -> Result<(H, Self), H::Error> { + ) -> Result<(), H::Error> { // If we've successfully read a packet. trace!("process_packet buf = {:?} bytes", buf.len()); trace!("buf = {:?}", buf); @@ -212,15 +226,14 @@ impl Session { .map_err(|_| crate::Error::SendError)?; enc.state = EncryptedState::InitCompression; enc.server_compression.init_decompress(&mut enc.decompress); - return Ok((client, self)); + return Ok(()); } else if buf.first() == Some(&msg::USERAUTH_BANNER) { let mut r = buf.reader(1); let banner = r.read_string().map_err(crate::Error::from)?; return if let Ok(banner) = std::str::from_utf8(banner) { - let (h, s) = client.auth_banner(banner, self).await?; - Ok((h, s)) + client.auth_banner(banner, self).await } else { - Ok((client, self)) + Ok(()) }; } else if buf.first() == Some(&msg::USERAUTH_FAILURE) { debug!("userauth_failure"); @@ -306,8 +319,7 @@ impl Session { }; // write responses enc.client_send_auth_response(&responses)?; - return Ok((client, self)); - } else { + return Ok(()); } // continue with userauth_pk_ok @@ -320,6 +332,14 @@ impl Session { &mut self.common.buffer, )? } + Some(auth_method @ auth::Method::OpenSSHCertificate { .. }) => { + self.common.buffer.clear(); + enc.client_send_signature( + &self.common.auth_user, + &auth_method, + &mut self.common.buffer, + )? + } Some(auth::Method::FuturePublicKey { key }) => { debug!("public key"); self.common.buffer.clear(); @@ -365,20 +385,20 @@ impl Session { if is_authenticated { self.client_read_authenticated(client, buf).await } else { - Ok((client, self)) + Ok(()) } } - fn handle_ext_info(self, client: H, buf: &[u8]) -> Result<(H, Self), H::Error> { + fn handle_ext_info(&mut self, _client: &mut H, buf: &[u8]) -> Result<(), H::Error> { debug!("Received EXT_INFO: {:?}", buf); - Ok((client, self)) + Ok(()) } async fn client_read_authenticated( - mut self, - mut client: H, + &mut self, + client: &mut H, buf: &[u8], - ) -> Result<(H, Self), H::Error> { + ) -> Result<(), H::Error> { match buf.first() { Some(&msg::CHANNEL_OPEN_CONFIRMATION) => { debug!("channel_open_confirmation"); @@ -589,7 +609,7 @@ impl Session { } else { warn!("Received keepalive without reply request!"); } - Ok((client, self)) + Ok(()) } _ => { let wants_reply = r.read_byte().map_err(crate::Error::from)?; @@ -607,7 +627,7 @@ impl Session { std::str::from_utf8(req), wants_reply ); - Ok((client, self)) + Ok(()) } } } @@ -631,6 +651,8 @@ impl Session { new_size -= enc.flush_pending(channel_num) as u32; } if let Some(chan) = self.channels.get(&channel_num) { + *chan.window_size().lock().await = new_size; + let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }); } client.window_adjusted(channel_num, new_size, self).await @@ -657,9 +679,6 @@ impl Session { match r.read_string() { Ok(key) => { let key2 = <&[u8]>::clone(&key); - #[cfg(not(feature = "openssl"))] - let key = parse_public_key(key).map_err(crate::Error::from); - #[cfg(feature = "openssl")] let key = parse_public_key(key, None).map_err(crate::Error::from); match key { @@ -689,7 +708,8 @@ impl Session { push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE)) } } - Ok((client, self)) + self.common.received_data = false; + Ok(()) } Some(&msg::CHANNEL_SUCCESS) => { let mut r = buf.reader(1); @@ -723,6 +743,8 @@ impl Session { confirmed: true, wants_reply: false, pending_data: std::collections::VecDeque::new(), + pending_eof: false, + pending_close: false, }; let confirm = || { @@ -736,16 +758,18 @@ impl Session { enc.channels.insert(id, channel); }; - Ok(match &msg.typ { + match &msg.typ { ChannelType::Session => { confirm(); - client.server_channel_open_session(id, self).await? + let channel = self.accept_server_initiated_channel(id, &msg); + client.server_channel_open_session(channel, self).await? } ChannelType::DirectTcpip(d) => { confirm(); + let channel = self.accept_server_initiated_channel(id, &msg); client .server_channel_open_direct_tcpip( - id, + channel, &d.host_to_connect, d.port_to_connect, &d.originator_address, @@ -783,27 +807,104 @@ impl Session { ) .await? } + ChannelType::ForwardedStreamLocal(d) => { + confirm(); + let channel = self.accept_server_initiated_channel(id, &msg); + client + .server_channel_open_forwarded_streamlocal( + channel, + &d.socket_path, + self, + ) + .await?; + } ChannelType::AgentForward => { confirm(); - client.server_channel_open_agent_forward(id, self).await? + let channel = self.accept_server_initiated_channel(id, &msg); + client + .server_channel_open_agent_forward(channel, self) + .await? } ChannelType::Unknown { typ } => { - if client.server_channel_handle_unknown(id, typ) { + if client.should_accept_unknown_server_channel(id, typ).await { confirm(); + let channel = self.accept_server_initiated_channel(id, &msg); + client.server_channel_open_unknown(channel, self).await?; } else { debug!("unknown channel type: {}", String::from_utf8_lossy(typ)); msg.unknown_type(&mut enc.write); } - (client, self) } - }) + }; + Ok(()) } else { Err(crate::Error::Inconsistent.into()) } } - _ => { - info!("Unhandled packet: {:?}", buf); - Ok((client, self)) + Some(&msg::REQUEST_SUCCESS) => { + trace!("Global Request Success"); + match self.open_global_requests.pop_front() { + Some(GlobalRequestResponse::Keepalive) => { + // ignore keepalives + } + Some(GlobalRequestResponse::TcpIpForward(return_channel)) => { + let result = if buf.len() == 1 { + // If a specific port was requested, the reply has no data + Some(0) + } else { + let mut r = buf.reader(1); + match r.read_u32() { + Ok(port) => Some(port), + Err(e) => { + error!("Error parsing port for TcpIpForward request: {e:?}"); + None + } + } + }; + let _ = return_channel.send(result); + } + Some(GlobalRequestResponse::CancelTcpIpForward(return_channel)) => { + let _ = return_channel.send(true); + } + Some(GlobalRequestResponse::StreamLocalForward(return_channel)) => { + let _ = return_channel.send(true); + } + Some(GlobalRequestResponse::CancelStreamLocalForward(return_channel)) => { + let _ = return_channel.send(true); + } + None => { + error!("Received global request failure for unknown request!") + } + } + Ok(()) + } + Some(&msg::REQUEST_FAILURE) => { + trace!("global request failure"); + match self.open_global_requests.pop_front() { + Some(GlobalRequestResponse::Keepalive) => { + // ignore keepalives + } + Some(GlobalRequestResponse::TcpIpForward(return_channel)) => { + let _ = return_channel.send(None); + } + Some(GlobalRequestResponse::CancelTcpIpForward(return_channel)) => { + let _ = return_channel.send(false); + } + Some(GlobalRequestResponse::StreamLocalForward(return_channel)) => { + let _ = return_channel.send(false); + } + Some(GlobalRequestResponse::CancelStreamLocalForward(return_channel)) => { + let _ = return_channel.send(false); + } + None => { + error!("Received global request failure for unknown request!") + } + } + Ok(()) + } + m => { + debug!("unknown message received: {:?}", m); + Ok(()) } } } @@ -813,15 +914,16 @@ impl Session { id: ChannelId, msg: &OpenChannelMessage, ) -> Channel { - let (sender, receiver) = unbounded_channel(); - self.channels.insert(id, sender); - Channel { + let (channel, channel_ref) = Channel::new( id, - sender: self.inbound_channel_sender.clone(), - receiver, - max_packet_size: msg.recipient_maximum_packet_size, - window_size: msg.recipient_window_size, - } + self.inbound_channel_sender.clone(), + msg.recipient_maximum_packet_size, + msg.recipient_window_size, + ); + + self.channels.insert(id, channel_ref); + + channel } pub(crate) fn write_auth_request_if_needed(&mut self, user: &str, meth: auth::Method) -> bool { @@ -888,11 +990,22 @@ impl Encrypted { self.write.extend_ssh_string(b"publickey"); self.write.push(0); // This is a probe - debug!("write_auth_request: {:?}", key.name()); + debug!("write_auth_request: key - {:?}", key.name()); self.write.extend_ssh_string(key.name().as_bytes()); key.push_to(&mut self.write); true } + auth::Method::OpenSSHCertificate { ref cert, .. } => { + self.write.extend_ssh_string(user.as_bytes()); + self.write.extend_ssh_string(b"ssh-connection"); + self.write.extend_ssh_string(b"publickey"); + self.write.push(0); // This is a probe + + debug!("write_auth_request: cert - {:?}", cert.name()); + self.write.extend_ssh_string(cert.name().as_bytes()); + cert.push_to(&mut self.write); + true + } auth::Method::FuturePublicKey { ref key, .. } => { self.write.extend_ssh_string(user.as_bytes()); self.write.extend_ssh_string(b"ssh-connection"); @@ -931,7 +1044,7 @@ impl Encrypted { buffer.extend_ssh_string(b"ssh-connection"); buffer.extend_ssh_string(b"publickey"); buffer.push(1); - buffer.extend_ssh_string(key.name().as_bytes()); + buffer.extend_ssh_string(key.name().as_bytes()); // TODO key.push_to(buffer); i0 } @@ -943,7 +1056,7 @@ impl Encrypted { buffer: &mut CryptoVec, ) -> Result<(), crate::Error> { match method { - auth::Method::PublicKey { ref key } => { + auth::Method::PublicKey { ref key, .. } => { let i0 = self.client_make_to_sign(user, key.as_ref(), buffer); // Extend with self-signature. key.add_self_signature(buffer)?; @@ -952,6 +1065,15 @@ impl Encrypted { self.write.extend(&buffer[i0..]); }) } + auth::Method::OpenSSHCertificate { ref key, ref cert } => { + let i0 = self.client_make_to_sign(user, cert, buffer); + // Extend with self-signature. + key.add_self_signature(buffer)?; + push_packet!(self.write, { + #[allow(clippy::indexing_slicing)] // length checked + self.write.extend(&buffer[i0..]); + }) + } _ => {} } Ok(()) diff --git a/russh/src/client/kex.rs b/russh/src/client/kex.rs index afd5ae62..92de368a 100644 --- a/russh/src/client/kex.rs +++ b/russh/src/client/kex.rs @@ -21,7 +21,7 @@ impl KexInit { // read algorithms from packet. debug!("extending {:?}", &self.exchange.server_kex_init[..]); self.exchange.server_kex_init.extend(buf); - negotiation::Client::read_kex(buf, &config.preferred)? + negotiation::Client::read_kex(buf, &config.preferred, None)? }; debug!("algo = {:?}", algo); debug!("write = {:?}", &write_buffer.buffer[..]); @@ -69,7 +69,7 @@ impl KexInit { write_buffer: &mut SSHBuffer, ) -> Result<(), crate::Error> { self.exchange.client_kex_init.clear(); - negotiation::write_kex(&config.preferred, &mut self.exchange.client_kex_init, false)?; + negotiation::write_kex(&config.preferred, &mut self.exchange.client_kex_init, None)?; self.sent = true; cipher.write(&self.exchange.client_kex_init, write_buffer); Ok(()) diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index 5cd442bb..38ec04c5 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -26,57 +26,18 @@ //! The [Session](client::Session) is passed to the [Handler](client::Handler) //! when the client receives data. //! -//! ```no_run -//! use async_trait::async_trait; -//! use std::sync::Arc; -//! use russh::*; -//! use russh::server::{Auth, Session}; -//! use russh_keys::*; -//! use futures::Future; -//! use std::io::Read; +//! Check out the following examples: //! -//! struct Client { -//! } -//! -//! #[async_trait] -//! impl client::Handler for Client { -//! type Error = anyhow::Error; -//! -//! async fn check_server_key(self, server_public_key: &key::PublicKey) -> Result<(Self, bool), Self::Error> { -//! println!("check_server_key: {:?}", server_public_key); -//! Ok((self, true)) -//! } -//! -//! async fn data(self, channel: ChannelId, data: &[u8], session: client::Session) -> Result<(Self, client::Session), Self::Error> { -//! println!("data on channel {:?}: {:?}", channel, std::str::from_utf8(data)); -//! Ok((self, session)) -//! } -//! } -//! -//! #[tokio::main] -//! async fn main() { -//! let config = russh::client::Config::default(); -//! let config = Arc::new(config); -//! let sh = Client{}; -//! -//! let key = russh_keys::key::KeyPair::generate_ed25519().unwrap(); -//! let mut agent = russh_keys::agent::client::AgentClient::connect_env().await.unwrap(); -//! agent.add_identity(&key, &[]).await.unwrap(); -//! let mut session = russh::client::connect(config, ("127.0.0.1", 22), sh).await.unwrap(); -//! if session.authenticate_future(std::env::var("USER").unwrap_or("user".to_owned()), key.clone_public_key().unwrap(), agent).await.1.unwrap() { -//! let mut channel = session.channel_open_session().await.unwrap(); -//! channel.data(&b"Hello, world!"[..]).await.unwrap(); -//! if let Some(msg) = channel.wait().await { -//! println!("{:?}", msg) -//! } -//! } -//! } -//! ``` +//! * [Client that connects to a server, runs a command and prints its output](https://github.com/warp-tech/russh/blob/main/russh/examples/client_exec_simple.rs) +//! * [Client that connects to a server, runs a command in a PTY and provides interactive input/output](https://github.com/warp-tech/russh/blob/main/russh/examples/client_exec_interactive.rs) +//! * [SFTP client (with `russh-sftp`)](https://github.com/warp-tech/russh/blob/main/russh/examples/sftp_client.rs) //! //! [Session]: client::Session use std::cell::RefCell; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; +use std::convert::TryInto; +use std::num::Wrapping; use std::pin::Pin; use std::sync::Arc; @@ -84,26 +45,30 @@ use async_trait::async_trait; use futures::task::{Context, Poll}; use futures::Future; use log::{debug, error, info, trace}; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::Reader; -#[cfg(feature = "openssl")] -use russh_keys::key::SignatureHash; -use russh_keys::key::{self, parse_public_key, PublicKey}; -use tokio; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use ssh_key::Certificate; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::pin; use tokio::sync::mpsc::{ channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender, }; +use tokio::sync::{oneshot, Mutex}; -use crate::channels::{Channel, ChannelMsg}; +use crate::channels::{Channel, ChannelMsg, ChannelRef}; use crate::cipher::{self, clear, CipherPair, OpeningKey}; use crate::key::PubKey; -use crate::session::{CommonSession, EncryptedState, Exchange, Kex, KexDhDone, KexInit, NewKeys}; +use crate::keys::encoding::Reader; +use crate::keys::key::{self, parse_public_key, PublicKey, SignatureHash}; +use crate::session::{ + CommonSession, EncryptedState, Exchange, GlobalRequestResponse, Kex, KexDhDone, KexInit, + NewKeys, +}; use crate::ssh_read::SshRead; use crate::sshbuffer::{SSHBuffer, SshId}; -use crate::{auth, msg, negotiation, ChannelId, ChannelOpenFailure, Disconnect, Limits, Sig}; +use crate::{ + auth, msg, negotiation, strict_kex_violation, ChannelId, ChannelOpenFailure, CryptoVec, + Disconnect, Limits, Sig, +}; mod encrypted; mod kex; @@ -118,14 +83,17 @@ pub struct Session { common: CommonSession>, receiver: Receiver, sender: UnboundedSender, - channels: HashMap>, + channels: HashMap, target_window_size: u32, pending_reads: Vec, pending_len: u32, inbound_channel_sender: Sender, inbound_channel_receiver: Receiver, + open_global_requests: VecDeque, } +const STRICT_KEX_MSG_ORDER: &[u8] = &[msg::KEXINIT, msg::KEX_ECDH_REPLY, msg::NEWKEYS]; + impl Drop for Session { fn drop(&mut self) { debug!("drop session") @@ -162,34 +130,46 @@ pub enum Msg { data: CryptoVec, }, ChannelOpenSession { - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenX11 { originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenDirectTcpIp { host_to_connect: String, port_to_connect: u32, originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenDirectStreamLocal { socket_path: String, - sender: UnboundedSender, + channel_ref: ChannelRef, }, TcpIpForward { - want_reply: bool, + /// Provide a channel for the reply result to request a reply from the server + reply_channel: Option>>, address: String, port: u32, }, CancelTcpIpForward { - want_reply: bool, + /// Provide a channel for the reply result to request a reply from the server + reply_channel: Option>, address: String, port: u32, }, + StreamLocalForward { + /// Provide a channel for the reply result to request a reply from the server + reply_channel: Option>, + socket_path: String, + }, + CancelStreamLocalForward { + /// Provide a channel for the reply result to request a reply from the server + reply_channel: Option>, + socket_path: String, + }, Close { id: ChannelId, }, @@ -224,6 +204,19 @@ pub struct Prompt { pub echo: bool, } +#[derive(Debug)] +pub struct RemoteDisconnectInfo { + pub reason_code: crate::Disconnect, + pub message: String, + pub lang_tag: String, +} + +#[derive(Debug)] +pub enum DisconnectReason + Send> { + ReceivedDisconnect(RemoteDisconnectInfo), + Error(E), +} + /// Handle to a session, used to send messages to a client outside of /// the request/response cycle. pub struct Handle { @@ -370,6 +363,24 @@ impl Handle { self.wait_recv_reply().await } + /// Perform public OpenSSH Certificate-based SSH authentication + pub async fn authenticate_openssh_cert>( + &mut self, + user: U, + key: Arc, + cert: Certificate, + ) -> Result { + let user = user.into(); + self.sender + .send(Msg::Authenticate { + user, + method: auth::Method::OpenSSHCertificate { key, cert }, + }) + .await + .map_err(|_| crate::Error::SendError)?; + self.wait_recv_reply().await + } + /// Authenticate using a custom method that implements the /// [`Signer`][auth::Signer] trait. Currently, this crate only provides an /// implementation for an [SSH @@ -418,6 +429,7 @@ impl Handle { async fn wait_channel_confirmation( &self, mut receiver: UnboundedReceiver, + window_size_ref: Arc>, ) -> Result, crate::Error> { loop { match receiver.recv().await { @@ -426,12 +438,14 @@ impl Handle { max_packet_size, window_size, }) => { + *window_size_ref.lock().await = window_size; + return Ok(Channel { id, sender: self.sender.clone(), receiver, max_packet_size, - window_size, + window_size: window_size_ref, }); } Some(ChannelMsg::OpenFailure(reason)) => { @@ -454,11 +468,15 @@ impl Handle { /// `confirmed` field of the corresponding `Channel`. pub async fn channel_open_session(&self) -> Result, crate::Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender - .send(Msg::ChannelOpenSession { sender }) + .send(Msg::ChannelOpenSession { channel_ref }) .await .map_err(|_| crate::Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } /// Request an X11 channel, on which the X11 protocol may be tunneled. @@ -468,15 +486,19 @@ impl Handle { originator_port: u32, ) -> Result, crate::Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenX11 { originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| crate::Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } /// Open a TCP/IP forwarding channel. This is usually done when a @@ -495,17 +517,21 @@ impl Handle { originator_port: u32, ) -> Result, crate::Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenDirectTcpIp { host_to_connect: host_to_connect.into(), port_to_connect, originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| crate::Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } pub async fn channel_open_direct_streamlocal>( @@ -513,49 +539,120 @@ impl Handle { socket_path: S, ) -> Result, crate::Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenDirectStreamLocal { socket_path: socket_path.into(), - sender, + channel_ref, }) .await .map_err(|_| crate::Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } + /// Requests the server to open a TCP/IP forward channel + /// + /// If port == 0 the server will choose a port that will be returned, returns 0 otherwise pub async fn tcpip_forward>( &mut self, address: A, port: u32, - ) -> Result { + ) -> Result { + let (reply_send, reply_recv) = oneshot::channel(); self.sender .send(Msg::TcpIpForward { - want_reply: true, + reply_channel: Some(reply_send), address: address.into(), port, }) .await .map_err(|_| crate::Error::SendError)?; - if port == 0 { - self.wait_recv_reply().await?; + + match reply_recv.await { + Ok(Some(port)) => Ok(port), + Ok(None) => Err(crate::Error::RequestDenied), + Err(e) => { + error!("Unable to receive TcpIpForward result: {e:?}"); + Err(crate::Error::Disconnect) + } } - Ok(true) } + // Requests the server to close a TCP/IP forward channel pub async fn cancel_tcpip_forward>( &self, address: A, port: u32, - ) -> Result { + ) -> Result<(), crate::Error> { + let (reply_send, reply_recv) = oneshot::channel(); self.sender .send(Msg::CancelTcpIpForward { - want_reply: true, + reply_channel: Some(reply_send), address: address.into(), port, }) .await .map_err(|_| crate::Error::SendError)?; - Ok(true) + + match reply_recv.await { + Ok(true) => Ok(()), + Ok(false) => Err(crate::Error::RequestDenied), + Err(e) => { + error!("Unable to receive CancelTcpIpForward result: {e:?}"); + Err(crate::Error::Disconnect) + } + } + } + + // Requests the server to open a UDS forward channel + pub async fn streamlocal_forward>( + &mut self, + socket_path: A, + ) -> Result<(), crate::Error> { + let (reply_send, reply_recv) = oneshot::channel(); + self.sender + .send(Msg::StreamLocalForward { + reply_channel: Some(reply_send), + socket_path: socket_path.into(), + }) + .await + .map_err(|_| crate::Error::SendError)?; + + match reply_recv.await { + Ok(true) => Ok(()), + Ok(false) => Err(crate::Error::RequestDenied), + Err(e) => { + error!("Unable to receive StreamLocalForward result: {e:?}"); + Err(crate::Error::Disconnect) + } + } + } + + // Requests the server to close a UDS forward channel + pub async fn cancel_streamlocal_forward>( + &self, + socket_path: A, + ) -> Result<(), crate::Error> { + let (reply_send, reply_recv) = oneshot::channel(); + self.sender + .send(Msg::CancelStreamLocalForward { + reply_channel: Some(reply_send), + socket_path: socket_path.into(), + }) + .await + .map_err(|_| crate::Error::SendError)?; + + match reply_recv.await { + Ok(true) => Ok(()), + Ok(false) => Err(crate::Error::RequestDenied), + Err(e) => { + error!("Unable to receive CancelStreamLocalForward result: {e:?}"); + Err(crate::Error::Disconnect) + } + } } /// Sends a disconnect message. @@ -672,15 +769,21 @@ where wants_reply: false, disconnected: false, buffer: CryptoVec::new(), + strict_kex: false, + alive_timeouts: 0, + received_data: false, + remote_sshid: sshid.into(), }, session_receiver, session_sender, ); session.read_ssh_id(sshid)?; - let (encrypted_signal, encrypted_recv) = tokio::sync::oneshot::channel(); - let join = tokio::spawn(session.run(stream, handler, Some(encrypted_signal))); + let (kex_done_signal, kex_done_signal_rx) = oneshot::channel(); + let join = tokio::spawn(session.run(stream, handler, Some(kex_done_signal))); - if encrypted_recv.await.is_err() { + if kex_done_signal_rx.await.is_err() { + // kex_done_signal Sender is dropped when the session + // fails before a succesful key exchange join.await.map_err(crate::Error::Join)??; return Err(H::Error::from(crate::Error::Disconnect)); } @@ -720,42 +823,98 @@ impl Session { channels: HashMap::new(), pending_reads: Vec::new(), pending_len: 0, + open_global_requests: VecDeque::new(), } } async fn run( mut self, - mut stream: SshRead, + stream: SshRead, mut handler: H, - mut encrypted_signal: Option>, + mut kex_done_signal: Option>, ) -> Result<(), H::Error> { + let (stream_read, mut stream_write) = stream.split(); + let result = self + .run_inner( + stream_read, + &mut stream_write, + &mut handler, + &mut kex_done_signal, + ) + .await; + trace!("disconnected"); + self.receiver.close(); + self.inbound_channel_receiver.close(); + stream_write.shutdown().await.map_err(crate::Error::from)?; + match result { + Ok(v) => { + handler + .disconnected(DisconnectReason::ReceivedDisconnect(v)) + .await?; + Ok(()) + } + Err(e) => { + if kex_done_signal.is_some() { + // The kex signal has not been consumed yet, + // so we can send return the concrete error to be propagated + // into the JoinHandle and returned from `connect_stream` + Err(e) + } else { + // The kex signal has been consumed, so no one is + // awaiting the result of this coroutine + // We're better off passing the error into the Handler + handler.disconnected(DisconnectReason::Error(e)).await?; + Err(H::Error::from(crate::Error::Disconnect)) + } + } + } + } + + async fn run_inner( + &mut self, + stream_read: SshRead>, + stream_write: &mut WriteHalf, + handler: &mut H, + kex_done_signal: &mut Option>, + ) -> Result { + let mut result: Result = + Err(crate::Error::Disconnect.into()); self.flush()?; if !self.common.write_buffer.buffer.is_empty() { debug!("writing {:?} bytes", self.common.write_buffer.buffer.len()); - stream + stream_write .write_all(&self.common.write_buffer.buffer) .await .map_err(crate::Error::from)?; - stream.flush().await.map_err(crate::Error::from)?; + stream_write.flush().await.map_err(crate::Error::from)?; } self.common.write_buffer.buffer.clear(); let mut decomp = CryptoVec::new(); - let (stream_read, mut stream_write) = stream.split(); let buffer = SSHBuffer::new(); // Allow handing out references to the cipher let mut opening_cipher = Box::new(clear::Key) as Box; std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); + let keepalive_timer = + crate::future_or_pending(self.common.config.keepalive_interval, tokio::time::sleep); + pin!(keepalive_timer); + + let inactivity_timer = + crate::future_or_pending(self.common.config.inactivity_timeout, tokio::time::sleep); + pin!(inactivity_timer); + let reading = start_reading(stream_read, buffer, opening_cipher); pin!(reading); #[allow(clippy::panic)] // false positive in select! macro while !self.common.disconnected { + self.common.received_data = false; + let mut sent_keepalive = false; tokio::select! { r = &mut reading => { - let (stream_read, buffer, mut opening_cipher) = match r { + let (stream_read, mut buffer, mut opening_cipher) = match r { Ok((_, stream_read, buffer, opening_cipher)) => (stream_read, buffer, opening_cipher), Err(e) => return Err(e.into()) }; @@ -783,17 +942,29 @@ impl Session { if !buf.is_empty() { #[allow(clippy::indexing_slicing)] // length checked if buf[0] == crate::msg::DISCONNECT { - break; - } else if buf[0] > 4 { - let (h, s) = reply(self, handler, &mut encrypted_signal, buf).await?; - handler = h; - self = s; + result = self.process_disconnect(buf); + } else { + self.common.received_data = true; + reply(self, handler, kex_done_signal, &mut buffer.seqn, buf).await?; } } std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); reading.set(start_reading(stream_read, buffer, opening_cipher)); } + () = &mut keepalive_timer => { + if self.common.config.keepalive_max != 0 && self.common.alive_timeouts > self.common.config.keepalive_max { + debug!("Timeout, server not responding to keepalives"); + return Err(crate::Error::KeepaliveTimeout.into()); + } + self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1); + self.send_keepalive(true); + sent_keepalive = true; + } + () = &mut inactivity_timer => { + debug!("timeout"); + return Err(crate::Error::InactivityTimeout.into()); + } msg = self.receiver.recv(), if !self.is_rekeying() => { match msg { Some(msg) => self.handle_msg(msg)?, @@ -825,7 +996,8 @@ impl Session { } } } - } + }; + self.flush()?; if !self.common.write_buffer.buffer.is_empty() { trace!( @@ -845,12 +1017,56 @@ impl Session { enc.state = EncryptedState::Authenticated; } } + + if self.common.received_data { + // Reset the number of failed keepalive attempts. We don't + // bother detecting keepalive response messages specifically + // (OpenSSH_9.6p1 responds with REQUEST_FAILURE aka 82). Instead + // we assume that the server is still alive if we receive any + // data from it. + self.common.alive_timeouts = 0; + } + if self.common.received_data || sent_keepalive { + if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + keepalive_timer.as_mut().as_pin_mut(), + self.common.config.keepalive_interval, + ) { + sleep.as_mut().reset(tokio::time::Instant::now() + d); + } + } + if !sent_keepalive { + if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + inactivity_timer.as_mut().as_pin_mut(), + self.common.config.inactivity_timeout, + ) { + sleep.as_mut().reset(tokio::time::Instant::now() + d); + } + } } - debug!("disconnected"); - if self.common.disconnected { - stream_write.shutdown().await.map_err(crate::Error::from)?; - } - Ok(()) + + result + } + + fn process_disconnect + Send>( + &mut self, + buf: &[u8], + ) -> Result { + self.common.disconnected = true; + let mut reader = buf.reader(1); + + let reason_code = reader.read_u32().map_err(crate::Error::from)?.try_into()?; + let message = std::str::from_utf8(reader.read_string().map_err(crate::Error::from)?) + .map_err(crate::Error::from)? + .to_owned(); + let lang_tag = std::str::from_utf8(reader.read_string().map_err(crate::Error::from)?) + .map_err(crate::Error::from)? + .to_owned(); + + Ok(RemoteDisconnectInfo { + reason_code, + message, + lang_tag, + }) } fn handle_msg(&mut self, msg: Msg) -> Result<(), crate::Error> { @@ -860,24 +1076,24 @@ impl Session { } Msg::Signed { .. } => {} Msg::AuthInfoResponse { .. } => {} - Msg::ChannelOpenSession { sender } => { + Msg::ChannelOpenSession { channel_ref } => { let id = self.channel_open_session()?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Msg::ChannelOpenX11 { originator_address, originator_port, - sender, + channel_ref, } => { let id = self.channel_open_x11(&originator_address, originator_port)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Msg::ChannelOpenDirectTcpIp { host_to_connect, port_to_connect, originator_address, originator_port, - sender, + channel_ref, } => { let id = self.channel_open_direct_tcpip( &host_to_connect, @@ -885,25 +1101,33 @@ impl Session { &originator_address, originator_port, )?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Msg::ChannelOpenDirectStreamLocal { socket_path, - sender, + channel_ref, } => { let id = self.channel_open_direct_streamlocal(&socket_path)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Msg::TcpIpForward { - want_reply, + reply_channel, address, port, - } => self.tcpip_forward(want_reply, &address, port), + } => self.tcpip_forward(reply_channel, &address, port), Msg::CancelTcpIpForward { - want_reply, + reply_channel, address, port, - } => self.cancel_tcpip_forward(want_reply, &address, port), + } => self.cancel_tcpip_forward(reply_channel, &address, port), + Msg::StreamLocalForward { + reply_channel, + socket_path, + } => self.streamlocal_forward(reply_channel, &socket_path), + Msg::CancelStreamLocalForward { + reply_channel, + socket_path, + } => self.cancel_streamlocal_forward(reply_channel, &socket_path), Msg::Disconnect { reason, description, @@ -1041,7 +1265,7 @@ impl Session { )? { info!("Re-exchanging keys"); if enc.rekey.is_none() { - if let Some(exchange) = std::mem::replace(&mut enc.exchange, None) { + if let Some(exchange) = enc.exchange.take() { let mut kexinit = KexInit::initiate_rekey(exchange, &enc.session_id); kexinit.client_write( self.common.config.as_ref(), @@ -1075,22 +1299,19 @@ impl KexDhDone { async fn server_key_check( mut self, rekey: bool, - mut handler: H, + handler: &mut H, buf: &[u8], - ) -> Result<(NewKeys, H), H::Error> { + ) -> Result { let mut reader = buf.reader(1); let pubkey = reader.read_string().map_err(crate::Error::from)?; // server public key. let pubkey = parse_public_key( pubkey, - #[cfg(feature = "openssl")] SignatureHash::from_rsa_hostkey_algo(self.names.key.0.as_bytes()), ) .map_err(crate::Error::from)?; debug!("server_public_Key: {:?}", pubkey); if !rekey { - let ret = handler.check_server_key(&pubkey).await?; - handler = ret.0; - let check = ret.1; + let check = handler.check_server_key(&pubkey).await?; if !check { return Err(crate::Error::UnknownKey.into()); } @@ -1121,7 +1342,7 @@ impl KexDhDone { debug!("sig_type: {:?}", sig_type); sig_reader.read_string().map_err(crate::Error::from)? }; - use russh_keys::key::Verify; + use crate::keys::key::Verify; debug!("signature: {:?}", signature); if !pubkey.verify_server_auth(hash.as_ref(), signature) { debug!("wrong server sig"); @@ -1131,17 +1352,33 @@ impl KexDhDone { }; let mut newkeys = self.compute_keys(hash, false)?; newkeys.sent = true; - Ok((newkeys, handler)) + Ok(newkeys) }) } } async fn reply( - mut session: Session, - mut handler: H, - sender: &mut Option>, + session: &mut Session, + handler: &mut H, + kex_done_signal: &mut Option>, + seqn: &mut Wrapping, buf: &[u8], -) -> Result<(H, Session), H::Error> { +) -> Result<(), H::Error> { + if let Some(message_type) = buf.first() { + if session.common.strict_kex && session.common.encrypted.is_none() { + let seqno = seqn.0 - 1; // was incremented after read() + if let Some(expected) = STRICT_KEX_MSG_ORDER.get(seqno as usize) { + if message_type != expected { + return Err(strict_kex_violation(*message_type, seqno as usize).into()); + } + } + } + + if [msg::IGNORE, msg::UNIMPLEMENTED, msg::DEBUG].contains(message_type) { + return Ok(()); + } + } + match session.common.kex.take() { Some(Kex::Init(kexinit)) => { if kexinit.algo.is_some() @@ -1155,13 +1392,18 @@ async fn reply( &mut session.common.write_buffer, )?; + // seqno has already been incremented after read() + if done.names.strict_kex && seqn.0 != 1 { + return Err(strict_kex_violation(msg::KEXINIT, seqn.0 as usize - 1).into()); + } + if done.kex.skip_exchange() { session.common.encrypted( - initial_encrypted_state(&session), + initial_encrypted_state(session), done.compute_keys(CryptoVec::new(), false)?, ); - if let Some(sender) = sender.take() { + if let Some(sender) = kex_done_signal.take() { sender.send(()).unwrap_or(()); } } else { @@ -1169,17 +1411,17 @@ async fn reply( } session.flush()?; } - Ok((handler, session)) + Ok(()) } Some(Kex::DhDone(mut kexdhdone)) => { if kexdhdone.names.ignore_guessed { kexdhdone.names.ignore_guessed = false; session.common.kex = Some(Kex::DhDone(kexdhdone)); - Ok((handler, session)) + Ok(()) } else if buf.first() == Some(&msg::KEX_ECDH_REPLY) { // We've sent ECDH_INIT, waiting for ECDH_REPLY - let (kex, h) = kexdhdone.server_key_check(false, handler, buf).await?; - handler = h; + let kex = kexdhdone.server_key_check(false, handler, buf).await?; + session.common.strict_kex = session.common.strict_kex || kex.names.strict_kex; session.common.kex = Some(Kex::Keys(kex)); session .common @@ -1187,7 +1429,8 @@ async fn reply( .local_to_remote .write(&[msg::NEWKEYS], &mut session.common.write_buffer); session.flush()?; - Ok((handler, session)) + session.common.maybe_reset_seqn(); + Ok(()) } else { error!("Wrong packet received"); Err(crate::Error::Inconsistent.into()) @@ -1198,20 +1441,23 @@ async fn reply( if buf.first() != Some(&msg::NEWKEYS) { return Err(crate::Error::Kex.into()); } - if let Some(sender) = sender.take() { + if let Some(sender) = kex_done_signal.take() { sender.send(()).unwrap_or(()); } session .common - .encrypted(initial_encrypted_state(&session), newkeys); + .encrypted(initial_encrypted_state(session), newkeys); // Ok, NEWKEYS received, now encrypted. - Ok((handler, session)) + if session.common.strict_kex { + *seqn = Wrapping(0); + } + Ok(()) } Some(kex) => { session.common.kex = Some(kex); - Ok((handler, session)) + Ok(()) } - None => session.client_read_encrypted(handler, buf).await, + None => session.client_read_encrypted(handler, seqn, buf).await, } } @@ -1241,6 +1487,10 @@ pub struct Config { pub preferred: negotiation::Preferred, /// Time after which the connection is garbage-collected. pub inactivity_timeout: Option, + /// If nothing is received from the server for this amount of time, send a keepalive message. + pub keepalive_interval: Option, + /// If this many keepalives have been sent without reply, close the connection. + pub keepalive_max: usize, /// Whether to expect and wait for an authentication call. pub anonymous: bool, } @@ -1258,6 +1508,8 @@ impl Default for Config { maximum_packet_size: 32768, preferred: Default::default(), inactivity_timeout: None, + keepalive_interval: None, + keepalive_max: 3, anonymous: false, } } @@ -1270,21 +1522,19 @@ impl Default for Config { #[async_trait] pub trait Handler: Sized + Send { - type Error: From + Send; + type Error: From + Send + core::fmt::Debug; /// Called when the server sends us an authentication banner. This /// is usually meant to be shown to the user, see /// [RFC4252](https://tools.ietf.org/html/rfc4252#section-5.4) for /// more details. - /// - /// The returned Boolean is ignored. #[allow(unused_variables)] async fn auth_banner( - self, + &mut self, banner: &str, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called to check the server's public key. This is a very important @@ -1292,10 +1542,10 @@ pub trait Handler: Sized + Send { /// implementation rejects all keys. #[allow(unused_variables)] async fn check_server_key( - self, + &mut self, server_public_key: &key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - Ok((self, false)) + ) -> Result { + Ok(false) } /// Called when the server confirmed our request to open a @@ -1303,134 +1553,160 @@ pub trait Handler: Sized + Send { /// message (this library panics otherwise). #[allow(unused_variables)] async fn channel_open_confirmation( - self, + &mut self, id: ChannelId, max_packet_size: u32, window_size: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the server signals success. #[allow(unused_variables)] async fn channel_success( - self, + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the server signals failure. #[allow(unused_variables)] async fn channel_failure( - self, + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the server closes a channel. #[allow(unused_variables)] async fn channel_close( - self, + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the server sends EOF to a channel. #[allow(unused_variables)] async fn channel_eof( - self, + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the server rejected our request to open a channel. #[allow(unused_variables)] async fn channel_open_failure( - self, + &mut self, channel: ChannelId, reason: ChannelOpenFailure, description: &str, language: &str, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the server opens a channel for a new remote port forwarding connection #[allow(unused_variables)] async fn server_channel_open_forwarded_tcpip( - self, + &mut self, channel: Channel, connected_address: &str, connected_port: u32, originator_address: &str, originator_port: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) + } + + // Called when the server opens a channel for a new remote UDS forwarding connection + #[allow(unused_variables)] + async fn server_channel_open_forwarded_streamlocal( + &mut self, + channel: Channel, + socket_path: &str, + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the server opens an agent forwarding channel #[allow(unused_variables)] async fn server_channel_open_agent_forward( - self, - channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + &mut self, + channel: Channel, + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } - /// Called when the server gets an unknown channel. It may return `true`, - /// if the channel of unknown type should be handled. If it returns `false`, - /// the channel will not be created and an error will be sent to the server. + /// Called when the server attempts to open a channel of unknown type. It may return `true`, + /// if the channel of unknown type should be accepted. In this case, + /// [Handler::server_channel_open_unknown] will be called soon after. If it returns `false`, + /// the channel will not be created and a rejection message will be sent to the server. #[allow(unused_variables)] - fn server_channel_handle_unknown(&self, channel: ChannelId, channel_type: &[u8]) -> bool { + async fn should_accept_unknown_server_channel( + &mut self, + id: ChannelId, + channel_type: &[u8], + ) -> bool { false } + /// Called when the server opens an unknown channel. + #[allow(unused_variables)] + async fn server_channel_open_unknown( + &mut self, + channel: Channel, + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) + } + /// Called when the server opens a session channel. #[allow(unused_variables)] async fn server_channel_open_session( - self, - channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + &mut self, + channel: Channel, + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the server opens a direct tcp/ip channel. #[allow(unused_variables)] async fn server_channel_open_direct_tcpip( - self, - channel: ChannelId, + &mut self, + channel: Channel, host_to_connect: &str, port_to_connect: u32, originator_address: &str, originator_port: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the server opens an X11 channel. #[allow(unused_variables)] async fn server_channel_open_x11( - self, + &mut self, channel: Channel, originator_address: &str, originator_port: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the server sends us data. The `extended_code` @@ -1439,12 +1715,12 @@ pub trait Handler: Sized + Send { /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-5.2). #[allow(unused_variables)] async fn data( - self, + &mut self, channel: ChannelId, data: &[u8], - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the server sends us data. The `extended_code` @@ -1453,13 +1729,13 @@ pub trait Handler: Sized + Send { /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-5.2). #[allow(unused_variables)] async fn extended_data( - self, + &mut self, channel: ChannelId, ext: u32, data: &[u8], - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// The server informs this client of whether the client may @@ -1467,37 +1743,37 @@ pub trait Handler: Sized + Send { /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-6.8). #[allow(unused_variables)] async fn xon_xoff( - self, + &mut self, channel: ChannelId, client_can_do: bool, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// The remote process has exited, with the given exit status. #[allow(unused_variables)] async fn exit_status( - self, + &mut self, channel: ChannelId, exit_status: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// The remote process exited upon receiving a signal. #[allow(unused_variables)] async fn exit_signal( - self, + &mut self, channel: ChannelId, signal_name: Sig, core_dumped: bool, error_message: &str, lang_tag: &str, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the network window is adjusted, meaning that we @@ -1507,12 +1783,12 @@ pub trait Handler: Sized + Send { /// full amount of data. #[allow(unused_variables)] async fn window_adjusted( - self, + &mut self, channel: ChannelId, new_size: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when this client adjusts the network window. Return the @@ -1525,11 +1801,26 @@ pub trait Handler: Sized + Send { /// Called when the server signals success. #[allow(unused_variables)] async fn openssh_ext_host_keys_announced( - self, + &mut self, keys: Vec, - session: Session, - ) -> Result<(Self, Session), Self::Error> { + session: &mut Session, + ) -> Result<(), Self::Error> { debug!("openssh_ext_hostkeys_announced: {:?}", keys); - Ok((self, session)) + Ok(()) + } + + /// Called when the server sent a disconnect message + /// + /// If reason is an Error, this function should re-return the error so the join can also evaluate it + #[allow(unused_variables)] + async fn disconnected( + &mut self, + reason: DisconnectReason, + ) -> Result<(), Self::Error> { + debug!("disconnected: {:?}", reason); + match reason { + DisconnectReason::ReceivedDisconnect(_) => Ok(()), + DisconnectReason::Error(e) => Err(e), + } } } diff --git a/russh/src/client/session.rs b/russh/src/client/session.rs index adc87da9..26f8a761 100644 --- a/russh/src/client/session.rs +++ b/russh/src/client/session.rs @@ -1,10 +1,10 @@ -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::Encoding; use log::error; +use tokio::sync::oneshot; use crate::client::Session; +use crate::keys::encoding::Encoding; use crate::session::EncryptedState; -use crate::{msg, ChannelId, Disconnect, Pty, Sig}; +use crate::{msg, ChannelId, CryptoVec, Disconnect, Pty, Sig}; impl Session { fn channel_open_generic( @@ -81,7 +81,7 @@ impl Session { pub fn channel_open_direct_streamlocal( &mut self, - socket_path: &str + socket_path: &str, ) -> Result { self.channel_open_generic(b"direct-streamlocal@openssh.com", |write| { write.extend_ssh_string(socket_path.as_bytes()); @@ -264,8 +264,23 @@ impl Session { } } - pub fn tcpip_forward(&mut self, want_reply: bool, address: &str, port: u32) { + /// Requests a TCP/IP forwarding from the server + /// + /// If `reply_channel` is not None, sets want_reply and returns the server's response via the channel, + /// [`Some`] for a success message with port, or [`None`] for failure + pub fn tcpip_forward( + &mut self, + reply_channel: Option>>, + address: &str, + port: u32, + ) { if let Some(ref mut enc) = self.common.encrypted { + let want_reply = reply_channel.is_some(); + if let Some(reply_channel) = reply_channel { + self.open_global_requests.push_back( + crate::session::GlobalRequestResponse::TcpIpForward(reply_channel), + ); + } push_packet!(enc.write, { enc.write.push(msg::GLOBAL_REQUEST); enc.write.extend_ssh_string(b"tcpip-forward"); @@ -276,8 +291,23 @@ impl Session { } } - pub fn cancel_tcpip_forward(&mut self, want_reply: bool, address: &str, port: u32) { + /// Requests cancellation of TCP/IP forwarding from the server + /// + /// If `reply_channel` is not None, sets want_reply and returns the server's response via the channel, + /// `true` for a success message, or `false` for failure + pub fn cancel_tcpip_forward( + &mut self, + reply_channel: Option>, + address: &str, + port: u32, + ) { if let Some(ref mut enc) = self.common.encrypted { + let want_reply = reply_channel.is_some(); + if let Some(reply_channel) = reply_channel { + self.open_global_requests.push_back( + crate::session::GlobalRequestResponse::CancelTcpIpForward(reply_channel), + ); + } push_packet!(enc.write, { enc.write.push(msg::GLOBAL_REQUEST); enc.write.extend_ssh_string(b"cancel-tcpip-forward"); @@ -288,6 +318,70 @@ impl Session { } } + /// Requests a UDS forwarding from the server, `socket path` being the server side socket path. + /// + /// If `reply_channel` is not None, sets want_reply and returns the server's response via the channel, + /// `true` for a success message, or `false` for failure + pub fn streamlocal_forward( + &mut self, + reply_channel: Option>, + socket_path: &str, + ) { + if let Some(ref mut enc) = self.common.encrypted { + let want_reply = reply_channel.is_some(); + if let Some(reply_channel) = reply_channel { + self.open_global_requests.push_back( + crate::session::GlobalRequestResponse::StreamLocalForward(reply_channel), + ); + } + push_packet!(enc.write, { + enc.write.push(msg::GLOBAL_REQUEST); + enc.write + .extend_ssh_string(b"streamlocal-forward@openssh.com"); + enc.write.push(want_reply as u8); + enc.write.extend_ssh_string(socket_path.as_bytes()); + }); + } + } + + /// Requests cancellation of UDS forwarding from the server + /// + /// If `reply_channel` is not None, sets want_reply and returns the server's response via the channel, + /// `true` for a success message and `false` for failure. + pub fn cancel_streamlocal_forward( + &mut self, + reply_channel: Option>, + socket_path: &str, + ) { + if let Some(ref mut enc) = self.common.encrypted { + let want_reply = reply_channel.is_some(); + if let Some(reply_channel) = reply_channel { + self.open_global_requests.push_back( + crate::session::GlobalRequestResponse::CancelStreamLocalForward(reply_channel), + ); + } + push_packet!(enc.write, { + enc.write.push(msg::GLOBAL_REQUEST); + enc.write + .extend_ssh_string(b"cancel-streamlocal-forward@openssh.com"); + enc.write.push(want_reply as u8); + enc.write.extend_ssh_string(socket_path.as_bytes()); + }); + } + } + + pub fn send_keepalive(&mut self, want_reply: bool) { + self.open_global_requests + .push_back(crate::session::GlobalRequestResponse::Keepalive); + if let Some(ref mut enc) = self.common.encrypted { + push_packet!(enc.write, { + enc.write.push(msg::GLOBAL_REQUEST); + enc.write.extend_ssh_string(b"keepalive@openssh.com"); + enc.write.push(want_reply as u8); + }); + } + } + pub fn data(&mut self, channel: ChannelId, data: CryptoVec) { if let Some(ref mut enc) = self.common.encrypted { enc.data(channel, data) @@ -352,4 +446,17 @@ impl Session { 0 } } + + /// Returns the SSH ID (Protocol Version + Software Version) the server sent when connecting + /// + /// This should contain only ASCII characters for implementations conforming to RFC4253, Section 4.2: + /// + /// > Both the 'protoversion' and 'softwareversion' strings MUST consist of + /// > printable US-ASCII characters, with the exception of whitespace + /// > characters and the minus sign (-). + /// + /// So it usually is fine to convert it to a `String` using `String::from_utf8_lossy` + pub fn remote_sshid(&self) -> &[u8] { + &self.common.remote_sshid + } } diff --git a/russh/src/compression.rs b/russh/src/compression.rs index 20aff2cc..6d739bf6 100644 --- a/russh/src/compression.rs +++ b/russh/src/compression.rs @@ -1,4 +1,6 @@ -#[derive(Debug)] +use std::convert::TryFrom; + +#[derive(Debug, Clone)] pub enum Compression { None, #[cfg(feature = "flate2")] @@ -19,10 +21,43 @@ pub enum Decompress { Zlib(flate2::Decompress), } +#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash)] +pub struct Name(&'static str); +impl AsRef for Name { + fn as_ref(&self) -> &str { + self.0 + } +} + +impl TryFrom<&str> for Name { + type Error = (); + fn try_from(s: &str) -> Result { + ALL_COMPRESSION_ALGORITHMS + .iter() + .find(|x| x.0 == s) + .map(|x| **x) + .ok_or(()) + } +} + +pub const NONE: Name = Name("none"); +#[cfg(feature = "flate2")] +pub const ZLIB: Name = Name("zlib"); +#[cfg(feature = "flate2")] +pub const ZLIB_LEGACY: Name = Name("zlib@openssh.com"); + +pub const ALL_COMPRESSION_ALGORITHMS: &[&Name] = &[ + &NONE, + #[cfg(feature = "flate2")] + &ZLIB, + #[cfg(feature = "flate2")] + &ZLIB_LEGACY, +]; + #[cfg(feature = "flate2")] impl Compression { - pub fn from_string(s: &str) -> Self { - if s == "zlib" || s == "zlib@openssh.com" { + pub fn new(name: &Name) -> Self { + if name == &ZLIB || name == &ZLIB_LEGACY { Compression::Zlib } else { Compression::None @@ -56,7 +91,7 @@ impl Compression { #[cfg(not(feature = "flate2"))] impl Compression { - pub fn from_string(_: &str) -> Self { + pub fn new(_name: &Name) -> Self { Compression::None } @@ -71,7 +106,7 @@ impl Compress { &mut self, input: &'a [u8], _: &'a mut russh_cryptovec::CryptoVec, - ) -> Result<&'a [u8], Error> { + ) -> Result<&'a [u8], crate::Error> { Ok(input) } } @@ -82,7 +117,7 @@ impl Decompress { &mut self, input: &'a [u8], _: &'a mut russh_cryptovec::CryptoVec, - ) -> Result<&'a [u8], Error> { + ) -> Result<&'a [u8], crate::Error> { Ok(input) } } diff --git a/russh/src/kex/curve25519.rs b/russh/src/kex/curve25519.rs index 772b668e..c26267b4 100644 --- a/russh/src/kex/curve25519.rs +++ b/russh/src/kex/curve25519.rs @@ -3,13 +3,12 @@ use curve25519_dalek::constants::ED25519_BASEPOINT_TABLE; use curve25519_dalek::montgomery::MontgomeryPoint; use curve25519_dalek::scalar::Scalar; use log::debug; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::Encoding; use super::{compute_keys, KexAlgorithm, KexType}; +use crate::keys::encoding::Encoding; use crate::mac::{self}; use crate::session::Exchange; -use crate::{cipher, msg}; +use crate::{cipher, msg, CryptoVec}; pub struct Curve25519KexType {} @@ -103,9 +102,7 @@ impl KexAlgorithm for Curve25519Kex { } fn compute_shared_secret(&mut self, remote_pubkey_: &[u8]) -> Result<(), crate::Error> { - let local_secret = - std::mem::replace(&mut self.local_secret, None).ok_or(crate::Error::KexInit)?; - + let local_secret = self.local_secret.take().ok_or(crate::Error::KexInit)?; let mut remote_pubkey = MontgomeryPoint([0; 32]); remote_pubkey.0.clone_from_slice(remote_pubkey_); let shared = local_secret * remote_pubkey; diff --git a/russh/src/kex/dh/groups.rs b/russh/src/kex/dh/groups.rs index 56248c9c..bb1e992d 100644 --- a/russh/src/kex/dh/groups.rs +++ b/russh/src/kex/dh/groups.rs @@ -45,6 +45,38 @@ pub const DH_GROUP14: DhGroup = DhGroup { exp_size: 256, }; +pub const DH_GROUP16: DhGroup = DhGroup { + prime: hex!( + " + FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1 + 29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD + EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245 + E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED + EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D + C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F + 83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D + 670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B + E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9 + DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510 + 15728E5A 8AAAC42D AD33170D 04507A33 A85521AB DF1CBA64 + ECFB8504 58DBEF0A 8AEA7157 5D060C7D B3970F85 A6E1E4C7 + ABF5AE8C DB0933D7 1E8C94E0 4A25619D CEE3D226 1AD2EE6B + F12FFA06 D98A0864 D8760273 3EC86A64 521F2B18 177B200C + BBE11757 7A615D6C 770988C0 BAD946E2 08E24FA0 74E5AB31 + 43DB5BFC E0FD108E 4B82D120 A9210801 1A723C12 A787E6D7 + 88719A10 BDBA5B26 99C32718 6AF4E23C 1A946834 B6150BDA + 2583E9CA 2AD44CE8 DBBBC2DB 04DE8EF9 2E8EFC14 1FBECAA6 + 287C5947 4E6BC05D 99B2964F A090C3A2 233BA186 515BE7ED + 1F612970 CEE2D7AF B81BDD76 2170481C D0069127 D5B05AA9 + 93B4EA98 8D8FDDC1 86FFB7DC 90A6C08F 4DF435C9 34063199 + FFFFFFFF FFFFFFFF + " + ) + .as_slice(), + generator: 2, + exp_size: 512, +}; + #[derive(Debug, PartialEq, Eq, Clone)] pub struct DH { prime_num: BigUint, @@ -70,14 +102,8 @@ impl DH { pub fn generate_private_key(&mut self, is_server: bool) -> BigUint { let q = (&self.prime_num - &BigUint::from(1u8)) / &BigUint::from(2u8); let mut rng = rand::thread_rng(); - self.private_key = rng.gen_biguint_range( - &if is_server { - 1u8.into() - } else { - 2u8.into() - }, - &q, - ); + self.private_key = + rng.gen_biguint_range(&if is_server { 1u8.into() } else { 2u8.into() }, &q); self.private_key.clone() } diff --git a/russh/src/kex/dh/mod.rs b/russh/src/kex/dh/mod.rs index 0a8b8ddf..e409348d 100644 --- a/russh/src/kex/dh/mod.rs +++ b/russh/src/kex/dh/mod.rs @@ -6,15 +6,14 @@ use digest::Digest; use groups::DH; use log::debug; use num_bigint::BigUint; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::Encoding; use sha1::Sha1; -use sha2::Sha256; +use sha2::{Sha256, Sha512}; -use self::groups::{DhGroup, DH_GROUP1, DH_GROUP14}; +use self::groups::{DhGroup, DH_GROUP1, DH_GROUP14, DH_GROUP16}; use super::{compute_keys, KexAlgorithm, KexType}; +use crate::keys::encoding::Encoding; use crate::session::Exchange; -use crate::{cipher, mac, msg}; +use crate::{cipher, mac, msg, CryptoVec}; pub struct DhGroup1Sha1KexType {} @@ -38,6 +37,14 @@ impl KexType for DhGroup14Sha256KexType { } } +pub struct DhGroup16Sha512KexType {} + +impl KexType for DhGroup16Sha512KexType { + fn make(&self) -> Box { + Box::new(DhGroupKex::::new(&DH_GROUP16)) as Box + } +} + #[doc(hidden)] pub struct DhGroupKex { dh: DH, diff --git a/russh/src/kex/ecdh_nistp.rs b/russh/src/kex/ecdh_nistp.rs new file mode 100644 index 00000000..58ad899c --- /dev/null +++ b/russh/src/kex/ecdh_nistp.rs @@ -0,0 +1,234 @@ +use std::marker::PhantomData; + +use byteorder::{BigEndian, ByteOrder}; +use elliptic_curve::ecdh::{EphemeralSecret, SharedSecret}; +use elliptic_curve::point::PointCompression; +use elliptic_curve::sec1::{FromEncodedPoint, ModulusSize, ToEncodedPoint}; +use elliptic_curve::{AffinePoint, Curve, CurveArithmetic, FieldBytesSize}; +use log::debug; +use p256::NistP256; +use p384::NistP384; +use p521::NistP521; +use sha2::{Digest, Sha256, Sha384, Sha512}; + +use crate::kex::{compute_keys, KexAlgorithm, KexType}; +use crate::keys::encoding::Encoding; +use crate::mac::{self}; +use crate::session::Exchange; +use crate::{cipher, msg, CryptoVec}; + +pub struct EcdhNistP256KexType {} + +impl KexType for EcdhNistP256KexType { + fn make(&self) -> Box { + Box::new(EcdhNistPKex:: { + local_secret: None, + shared_secret: None, + _digest: PhantomData, + }) as Box + } +} + +pub struct EcdhNistP384KexType {} + +impl KexType for EcdhNistP384KexType { + fn make(&self) -> Box { + Box::new(EcdhNistPKex:: { + local_secret: None, + shared_secret: None, + _digest: PhantomData, + }) as Box + } +} + +pub struct EcdhNistP521KexType {} + +impl KexType for EcdhNistP521KexType { + fn make(&self) -> Box { + Box::new(EcdhNistPKex:: { + local_secret: None, + shared_secret: None, + _digest: PhantomData, + }) as Box + } +} + +#[doc(hidden)] +pub struct EcdhNistPKex { + local_secret: Option>, + shared_secret: Option>, + _digest: PhantomData, +} + +impl std::fmt::Debug for EcdhNistPKex { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "Algorithm {{ local_secret: [hidden], shared_secret: [hidden] }}", + ) + } +} + +impl KexAlgorithm for EcdhNistPKex +where + C: PointCompression, + FieldBytesSize: ModulusSize, + AffinePoint: FromEncodedPoint + ToEncodedPoint, +{ + fn skip_exchange(&self) -> bool { + false + } + + #[doc(hidden)] + fn server_dh(&mut self, exchange: &mut Exchange, payload: &[u8]) -> Result<(), crate::Error> { + debug!("server_dh"); + + let client_pubkey = { + if payload.first() != Some(&msg::KEX_ECDH_INIT) { + return Err(crate::Error::Inconsistent); + } + + #[allow(clippy::indexing_slicing)] // length checked + let pubkey_len = BigEndian::read_u32(&payload[1..]) as usize; + + if payload.len() < 5 + pubkey_len { + return Err(crate::Error::Inconsistent); + } + + #[allow(clippy::indexing_slicing)] // length checked + elliptic_curve::PublicKey::::from_sec1_bytes(&payload[5..(5 + pubkey_len)]) + .map_err(|_| crate::Error::Inconsistent)? + }; + + let server_secret = + elliptic_curve::ecdh::EphemeralSecret::::random(&mut rand_core::OsRng); + let server_pubkey = server_secret.public_key(); + + // fill exchange. + exchange.server_ephemeral.clear(); + exchange + .server_ephemeral + .extend(&server_pubkey.to_sec1_bytes()); + let shared = server_secret.diffie_hellman(&client_pubkey); + self.shared_secret = Some(shared); + Ok(()) + } + + #[doc(hidden)] + fn client_dh( + &mut self, + client_ephemeral: &mut CryptoVec, + buf: &mut CryptoVec, + ) -> Result<(), crate::Error> { + let client_secret = + elliptic_curve::ecdh::EphemeralSecret::::random(&mut rand_core::OsRng); + let client_pubkey = client_secret.public_key(); + + // fill exchange. + client_ephemeral.clear(); + client_ephemeral.extend(&client_pubkey.to_sec1_bytes()); + + buf.push(msg::KEX_ECDH_INIT); + buf.extend_ssh_string(&client_pubkey.to_sec1_bytes()); + + self.local_secret = Some(client_secret); + Ok(()) + } + + fn compute_shared_secret(&mut self, remote_pubkey_: &[u8]) -> Result<(), crate::Error> { + let local_secret = self.local_secret.take().ok_or(crate::Error::KexInit)?; + let pubkey = elliptic_curve::PublicKey::::from_sec1_bytes(remote_pubkey_) + .map_err(|_| crate::Error::KexInit)?; + self.shared_secret = Some(local_secret.diffie_hellman(&pubkey)); + Ok(()) + } + + fn compute_exchange_hash( + &self, + key: &CryptoVec, + exchange: &Exchange, + buffer: &mut CryptoVec, + ) -> Result { + // Computing the exchange hash, see page 7 of RFC 5656. + buffer.clear(); + buffer.extend_ssh_string(&exchange.client_id); + buffer.extend_ssh_string(&exchange.server_id); + buffer.extend_ssh_string(&exchange.client_kex_init); + buffer.extend_ssh_string(&exchange.server_kex_init); + + buffer.extend(key); + buffer.extend_ssh_string(&exchange.client_ephemeral); + buffer.extend_ssh_string(&exchange.server_ephemeral); + + if let Some(ref shared) = self.shared_secret { + buffer.extend_ssh_mpint(shared.raw_secret_bytes()); + } + + let mut hasher = D::new(); + hasher.update(&buffer); + + let mut res = CryptoVec::new(); + res.extend(hasher.finalize().as_slice()); + Ok(res) + } + + fn compute_keys( + &self, + session_id: &CryptoVec, + exchange_hash: &CryptoVec, + cipher: cipher::Name, + remote_to_local_mac: mac::Name, + local_to_remote_mac: mac::Name, + is_server: bool, + ) -> Result { + compute_keys::( + self.shared_secret + .as_ref() + .map(|x| x.raw_secret_bytes() as &[u8]), + session_id, + exchange_hash, + cipher, + remote_to_local_mac, + local_to_remote_mac, + is_server, + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_shared_secret() { + let mut party1 = EcdhNistPKex:: { + local_secret: Some(EphemeralSecret::::random(&mut rand_core::OsRng)), + shared_secret: None, + _digest: PhantomData, + }; + let p1_pubkey = party1.local_secret.as_ref().unwrap().public_key(); + + let mut party2 = EcdhNistPKex:: { + local_secret: Some(EphemeralSecret::::random(&mut rand_core::OsRng)), + shared_secret: None, + _digest: PhantomData, + }; + let p2_pubkey = party2.local_secret.as_ref().unwrap().public_key(); + + party1 + .compute_shared_secret(&p2_pubkey.to_sec1_bytes()) + .unwrap(); + + party2 + .compute_shared_secret(&p1_pubkey.to_sec1_bytes()) + .unwrap(); + + let p1_shared_secret = party1.shared_secret.unwrap(); + let p2_shared_secret = party2.shared_secret.unwrap(); + + assert_eq!( + p1_shared_secret.raw_secret_bytes(), + p2_shared_secret.raw_secret_bytes() + ) + } +} diff --git a/russh/src/kex/mod.rs b/russh/src/kex/mod.rs index cc413d65..c01ef42d 100644 --- a/russh/src/kex/mod.rs +++ b/russh/src/kex/mod.rs @@ -17,22 +17,26 @@ //! This module exports kex algorithm names for use with [Preferred]. mod curve25519; mod dh; +mod ecdh_nistp; mod none; use std::cell::RefCell; use std::collections::HashMap; +use std::convert::TryFrom; use std::fmt::Debug; use curve25519::Curve25519KexType; -use dh::{DhGroup14Sha1KexType, DhGroup14Sha256KexType, DhGroup1Sha1KexType}; +use dh::{ + DhGroup14Sha1KexType, DhGroup14Sha256KexType, DhGroup16Sha512KexType, DhGroup1Sha1KexType, +}; use digest::Digest; +use ecdh_nistp::{EcdhNistP256KexType, EcdhNistP384KexType, EcdhNistP521KexType}; use once_cell::sync::Lazy; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::Encoding; -use crate::cipher; use crate::cipher::CIPHERS; +use crate::keys::encoding::Encoding; use crate::mac::{self, MACS}; use crate::session::Exchange; +use crate::{cipher, CryptoVec}; pub(crate) trait KexType { fn make(&self) -> Box; @@ -83,6 +87,13 @@ impl AsRef for Name { } } +impl TryFrom<&str> for Name { + type Error = (); + fn try_from(s: &str) -> Result { + KEXES.keys().find(|x| x.0 == s).map(|x| **x).ok_or(()) + } +} + /// `curve25519-sha256` pub const CURVE25519: Name = Name("curve25519-sha256"); /// `curve25519-sha256@libssh.org` @@ -93,28 +104,62 @@ pub const DH_G1_SHA1: Name = Name("diffie-hellman-group1-sha1"); pub const DH_G14_SHA1: Name = Name("diffie-hellman-group14-sha1"); /// `diffie-hellman-group14-sha256` pub const DH_G14_SHA256: Name = Name("diffie-hellman-group14-sha256"); +/// `diffie-hellman-group16-sha512` +pub const DH_G16_SHA512: Name = Name("diffie-hellman-group16-sha512"); +/// `ecdh-sha2-nistp256` +pub const ECDH_SHA2_NISTP256: Name = Name("ecdh-sha2-nistp256"); +/// `ecdh-sha2-nistp384` +pub const ECDH_SHA2_NISTP384: Name = Name("ecdh-sha2-nistp384"); +/// `ecdh-sha2-nistp521` +pub const ECDH_SHA2_NISTP521: Name = Name("ecdh-sha2-nistp521"); /// `none` pub const NONE: Name = Name("none"); /// `ext-info-c` pub const EXTENSION_SUPPORT_AS_CLIENT: Name = Name("ext-info-c"); /// `ext-info-s` pub const EXTENSION_SUPPORT_AS_SERVER: Name = Name("ext-info-s"); +/// `kex-strict-c-v00@openssh.com` +pub const EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT: Name = Name("kex-strict-c-v00@openssh.com"); +/// `kex-strict-s-v00@openssh.com` +pub const EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER: Name = Name("kex-strict-s-v00@openssh.com"); const _CURVE25519: Curve25519KexType = Curve25519KexType {}; const _DH_G1_SHA1: DhGroup1Sha1KexType = DhGroup1Sha1KexType {}; const _DH_G14_SHA1: DhGroup14Sha1KexType = DhGroup14Sha1KexType {}; const _DH_G14_SHA256: DhGroup14Sha256KexType = DhGroup14Sha256KexType {}; +const _DH_G16_SHA512: DhGroup16Sha512KexType = DhGroup16Sha512KexType {}; +const _ECDH_SHA2_NISTP256: EcdhNistP256KexType = EcdhNistP256KexType {}; +const _ECDH_SHA2_NISTP384: EcdhNistP384KexType = EcdhNistP384KexType {}; +const _ECDH_SHA2_NISTP521: EcdhNistP521KexType = EcdhNistP521KexType {}; const _NONE: none::NoneKexType = none::NoneKexType {}; +pub const ALL_KEX_ALGORITHMS: &[&Name] = &[ + &CURVE25519, + &CURVE25519_PRE_RFC_8731, + &DH_G1_SHA1, + &DH_G14_SHA1, + &DH_G14_SHA256, + &DH_G16_SHA512, + &ECDH_SHA2_NISTP256, + &ECDH_SHA2_NISTP384, + &ECDH_SHA2_NISTP521, + &NONE, +]; + pub(crate) static KEXES: Lazy> = Lazy::new(|| { let mut h: HashMap<&'static Name, &(dyn KexType + Send + Sync)> = HashMap::new(); h.insert(&CURVE25519, &_CURVE25519); h.insert(&CURVE25519_PRE_RFC_8731, &_CURVE25519); + h.insert(&DH_G16_SHA512, &_DH_G16_SHA512); h.insert(&DH_G14_SHA256, &_DH_G14_SHA256); h.insert(&DH_G14_SHA1, &_DH_G14_SHA1); h.insert(&DH_G1_SHA1, &_DH_G1_SHA1); + h.insert(&ECDH_SHA2_NISTP256, &_ECDH_SHA2_NISTP256); + h.insert(&ECDH_SHA2_NISTP384, &_ECDH_SHA2_NISTP384); + h.insert(&ECDH_SHA2_NISTP521, &_ECDH_SHA2_NISTP521); h.insert(&NONE, &_NONE); + assert_eq!(ALL_KEX_ALGORITHMS.len(), h.len()); h }); diff --git a/russh/src/kex/none.rs b/russh/src/kex/none.rs index 66d903ac..5d421886 100644 --- a/russh/src/kex/none.rs +++ b/russh/src/kex/none.rs @@ -1,6 +1,5 @@ -use russh_cryptovec::CryptoVec; - use super::{KexAlgorithm, KexType}; +use crate::CryptoVec; pub struct NoneKexType {} diff --git a/russh/src/key.rs b/russh/src/key.rs index 17a28f68..b24fb176 100644 --- a/russh/src/key.rs +++ b/russh/src/key.rs @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. // -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::*; -use russh_keys::key::*; +use crate::keys::encoding::*; +use crate::keys::key::*; +use crate::keys::{ec, protocol}; +use crate::CryptoVec; #[doc(hidden)] pub trait PubKey { @@ -29,16 +30,14 @@ impl PubKey for PublicKey { buffer.extend_ssh_string(ED25519.0.as_bytes()); buffer.extend_ssh_string(public.as_bytes()); } - #[cfg(feature = "openssl")] PublicKey::RSA { ref key, .. } => { - #[allow(clippy::unwrap_used)] // type known - let rsa = key.0.rsa().unwrap(); - let e = rsa.e().to_vec(); - let n = rsa.n().to_vec(); - buffer.push_u32_be((4 + SSH_RSA.0.len() + mpint_len(&n) + mpint_len(&e)) as u32); - buffer.extend_ssh_string(SSH_RSA.0.as_bytes()); - buffer.extend_ssh_mpint(&e); - buffer.extend_ssh_mpint(&n); + buffer.extend_wrapped(|buffer| { + buffer.extend_ssh_string(SSH_RSA.0.as_bytes()); + buffer.extend_ssh(&protocol::RsaPublicKey::from(key)); + }); + } + PublicKey::EC { ref key } => { + write_ec_public_key(buffer, key); } } } @@ -53,15 +52,26 @@ impl PubKey for KeyPair { buffer.extend_ssh_string(ED25519.0.as_bytes()); buffer.extend_ssh_string(public.as_slice()); } - #[cfg(feature = "openssl")] KeyPair::RSA { ref key, .. } => { - let e = key.e().to_vec(); - let n = key.n().to_vec(); - buffer.push_u32_be((4 + SSH_RSA.0.len() + mpint_len(&n) + mpint_len(&e)) as u32); - buffer.extend_ssh_string(SSH_RSA.0.as_bytes()); - buffer.extend_ssh_mpint(&e); - buffer.extend_ssh_mpint(&n); + buffer.extend_wrapped(|buffer| { + buffer.extend_ssh_string(SSH_RSA.0.as_bytes()); + buffer.extend_ssh(&protocol::RsaPublicKey::from(key)); + }); + } + KeyPair::EC { ref key } => { + write_ec_public_key(buffer, &key.to_public_key()); } } } } + +pub(crate) fn write_ec_public_key(buf: &mut CryptoVec, key: &ec::PublicKey) { + let algorithm = key.algorithm().as_bytes(); + let ident = key.ident().as_bytes(); + let q = key.to_sec1_bytes(); + + buf.push_u32_be((algorithm.len() + ident.len() + q.len() + 12) as u32); + buf.extend_ssh_string(algorithm); + buf.extend_ssh_string(ident); + buf.extend_ssh_string(&q); +} diff --git a/russh/src/lib.rs b/russh/src/lib.rs index cf4cf8a9..59842843 100644 --- a/russh/src/lib.rs +++ b/russh/src/lib.rs @@ -30,11 +30,6 @@ //! * [Writing SSH clients - the `russh::client` module](client) //! * [Writing SSH servers - the `russh::server` module](server) //! -//! # Important crate features -//! -//! * RSA key support is gated behind the `openssl` feature (disabled by default). -//! * Enabling that and disabling the `rs-crypto` feature (enabled by default) will leave you with a very basic, but pure-OpenSSL RSA+AES cipherset. -//! //! # Using non-socket IO / writing tunnels //! //! The easy way to implement SSH tunnels, like `ProxyCommand` for @@ -94,22 +89,32 @@ //! messages sent through a `server::Handle` are processed when there //! is no incoming packet to read. +use std::convert::TryFrom; use std::fmt::{Debug, Display, Formatter}; -use thiserror::Error; +use log::debug; use parsing::ChannelOpenConfirmation; pub use russh_cryptovec::CryptoVec; +use thiserror::Error; + +#[cfg(test)] +mod tests; mod auth; /// Cipher names pub mod cipher; +/// Compression algorithm names +pub mod compression; /// Key exchange algorithm names pub mod kex; /// MAC algorithm names pub mod mac; -mod compression; +/// Re-export of the `russh-keys` crate. +pub use russh_keys as keys; + +mod cert; mod key; mod msg; mod negotiation; @@ -139,10 +144,7 @@ macro_rules! push_packet { } mod channels; -pub use channels::{Channel, ChannelMsg}; - -mod channel_stream; -pub use channel_stream::ChannelStream; +pub use channels::{Channel, ChannelMsg, ChannelStream}; mod parsing; mod session; @@ -153,6 +155,15 @@ pub mod server; /// Client side of this library. pub mod client; +#[derive(Debug)] +pub enum AlgorithmKind { + Kex, + Key, + Cipher, + Compression, + Mac, +} + #[derive(Debug, Error)] pub enum Error { /// The key file could not be parsed. @@ -167,25 +178,13 @@ pub enum Error { #[error("Unknown algorithm")] UnknownAlgo, - /// No common key exchange algorithm. - #[error("No common key exchange algorithm")] - NoCommonKexAlgo, - - /// No common signature algorithm. - #[error("No common key algorithm")] - NoCommonKeyAlgo, - - /// No common cipher. - #[error("No common key cipher")] - NoCommonCipher, - - /// No common compression algorithm. - #[error("No common compression algorithm")] - NoCommonCompression, - - /// No common MAC algorithm. - #[error("No common MAC algorithm")] - NoCommonMac, + /// No common algorithm found during key exchange. + #[error("No common algorithm")] + NoCommonAlgo { + kind: AlgorithmKind, + ours: Vec, + theirs: Vec, + }, /// Invalid SSH version string. #[error("invalid SSH version string")] @@ -219,6 +218,10 @@ pub enum Error { #[error("Wrong server signature")] WrongServerSig, + /// Excessive packet size. + #[error("Bad packet size: {0}")] + PacketSize(usize), + /// Message received/sent on unopened channel. #[error("Channel not open")] WrongChannel, @@ -248,6 +251,14 @@ pub enum Error { #[error("Connection timeout")] ConnectionTimeout, + /// Keepalive timeout. + #[error("Keepalive timeout")] + KeepaliveTimeout, + + /// Inactivity timeout. + #[error("Inactivity timeout")] + InactivityTimeout, + /// Missing authentication method. #[error("No authentication method")] NoAuthMethod, @@ -261,6 +272,9 @@ pub enum Error { #[error("Failed to decrypt a packet")] DecryptionError, + #[error("The request was rejected by the other party")] + RequestDenied, + #[error(transparent)] Keys(#[from] russh_keys::Error), @@ -271,9 +285,11 @@ pub enum Error { Utf8(#[from] std::str::Utf8Error), #[error(transparent)] + #[cfg(feature = "flate2")] Compress(#[from] flate2::CompressError), #[error(transparent)] + #[cfg(feature = "flate2")] Decompress(#[from] flate2::DecompressError), #[error(transparent)] @@ -285,6 +301,23 @@ pub enum Error { #[error(transparent)] Elapsed(#[from] tokio::time::error::Elapsed), + + #[error("Violation detected during strict key exchange, message {message_type} at seq no {sequence_number}")] + StrictKeyExchangeViolation { + message_type: u8, + sequence_number: usize, + }, +} + +pub(crate) fn strict_kex_violation(message_type: u8, sequence_number: usize) -> crate::Error { + debug!( + "strict kex violated at sequence no. {:?}, message type: {:?}", + sequence_number, message_type + ); + crate::Error::StrictKeyExchangeViolation { + message_type, + sequence_number, + } } #[derive(Debug, Error)] @@ -329,6 +362,7 @@ pub use auth::{AgentAuthError, MethodSet, Signer}; /// A reason for disconnection. #[allow(missing_docs)] // This should be relatively self-explanatory. +#[allow(clippy::manual_non_exhaustive)] #[derive(Debug)] pub enum Disconnect { HostNotAllowedToConnect = 1, @@ -349,6 +383,31 @@ pub enum Disconnect { IllegalUserName = 15, } +impl TryFrom for Disconnect { + type Error = crate::Error; + + fn try_from(value: u32) -> Result { + Ok(match value { + 1 => Self::HostNotAllowedToConnect, + 2 => Self::ProtocolError, + 3 => Self::KeyExchangeFailed, + 4 => Self::Reserved, + 5 => Self::MACError, + 6 => Self::CompressionError, + 7 => Self::ServiceNotAvailable, + 8 => Self::ProtocolVersionNotSupported, + 9 => Self::HostKeyNotVerifiable, + 10 => Self::ConnectionLost, + 11 => Self::ByApplication, + 12 => Self::TooManyConnections, + 13 => Self::AuthCancelledByUser, + 14 => Self::NoMoreAuthMethodsAvailable, + 15 => Self::IllegalUserName, + _ => return Err(crate::Error::Inconsistent), + }) + } +} + /// The type of signals that can be sent to a remote process. If you /// plan to use custom signals, read [the /// RFC](https://tools.ietf.org/html/rfc4254#section-6.10) to @@ -432,10 +491,16 @@ impl ChannelOpenFailure { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] /// The identifier of a channel. pub struct ChannelId(u32); +impl From for u32 { + fn from(c: ChannelId) -> u32 { + c.0 + } +} + impl Display for ChannelId { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) @@ -455,6 +520,8 @@ pub(crate) struct ChannelParams { pub confirmed: bool, wants_reply: bool, pending_data: std::collections::VecDeque<(CryptoVec, Option, usize)>, + pending_eof: bool, + pending_close: bool, } impl ChannelParams { @@ -466,479 +533,12 @@ impl ChannelParams { } } -#[cfg(test)] -mod test_compress { - use std::collections::HashMap; - use std::sync::{Arc, Mutex}; - - use async_trait::async_trait; - use log::debug; - - use super::server::{Server as _, Session}; - use super::*; - use crate::server::Msg; - - #[tokio::test] - async fn compress_local_test() { - let _ = env_logger::try_init(); - - let client_key = russh_keys::key::KeyPair::generate_ed25519().unwrap(); - let mut config = server::Config::default(); - config.preferred = Preferred::COMPRESSED; - config.inactivity_timeout = None; // Some(std::time::Duration::from_secs(3)); - config.auth_rejection_time = std::time::Duration::from_secs(3); - config - .keys - .push(russh_keys::key::KeyPair::generate_ed25519().unwrap()); - let config = Arc::new(config); - let mut sh = Server { - clients: Arc::new(Mutex::new(HashMap::new())), - id: 0, - }; - - let socket = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = socket.local_addr().unwrap(); - - tokio::spawn(async move { - let (socket, _) = socket.accept().await.unwrap(); - let server = sh.new_client(socket.peer_addr().ok()); - server::run_stream(config, socket, server).await.unwrap(); - }); - - let mut config = client::Config::default(); - config.preferred = Preferred::COMPRESSED; - let config = Arc::new(config); - - dbg!(&addr); - let mut session = client::connect(config, addr, Client {}).await.unwrap(); - let authenticated = session - .authenticate_publickey( - std::env::var("USER").unwrap_or("user".to_owned()), - Arc::new(client_key), - ) - .await - .unwrap(); - assert!(authenticated); - let mut channel = session.channel_open_session().await.unwrap(); - - let data = &b"Hello, world!"[..]; - channel.data(data).await.unwrap(); - let msg = channel.wait().await.unwrap(); - match msg { - ChannelMsg::Data { data: msg_data } => { - assert_eq!(*data, *msg_data) - } - msg => panic!("Unexpected message {:?}", msg), - } - } - - #[derive(Clone)] - struct Server { - clients: Arc>>, - id: usize, - } - - impl server::Server for Server { - type Handler = Self; - fn new_client(&mut self, _: Option) -> Self { - let s = self.clone(); - self.id += 1; - s - } - } - - #[async_trait] - impl server::Handler for Server { - type Error = super::Error; - - async fn channel_open_session( - self, - channel: Channel, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - { - let mut clients = self.clients.lock().unwrap(); - clients.insert((self.id, channel.id()), session.handle()); - } - Ok((self, true, session)) - } - async fn auth_publickey( - self, - _: &str, - _: &russh_keys::key::PublicKey, - ) -> Result<(Self, server::Auth), Self::Error> { - debug!("auth_publickey"); - Ok((self, server::Auth::Accept)) - } - async fn data( - self, - channel: ChannelId, - data: &[u8], - mut session: Session, - ) -> Result<(Self, Session), Self::Error> { - debug!("server data = {:?}", std::str::from_utf8(data)); - session.data(channel, CryptoVec::from_slice(data)); - Ok((self, session)) - } - } - - struct Client {} - - #[async_trait] - impl client::Handler for Client { - type Error = super::Error; - - async fn check_server_key( - self, - _server_public_key: &russh_keys::key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - // println!("check_server_key: {:?}", server_public_key); - Ok((self, true)) - } - } -} - -#[cfg(test)] -use futures::Future; - -#[cfg(test)] -async fn test_session( - client_handler: CH, - server_handler: SH, - run_client: RC, - run_server: RS, -) where - RC: FnOnce(crate::client::Handle) -> F1 + Send + Sync + 'static, - RS: FnOnce(crate::server::Handle) -> F2 + Send + Sync + 'static, - F1: Future> + Send + Sync + 'static, - F2: Future + Send + Sync + 'static, - CH: crate::client::Handler + Send + Sync + 'static, - SH: crate::server::Handler + Send + Sync + 'static, -{ - use std::sync::Arc; - - use crate::*; - - let _ = env_logger::try_init(); - - let client_key = russh_keys::key::KeyPair::generate_ed25519().unwrap(); - let mut config = server::Config::default(); - config.inactivity_timeout = None; - config.auth_rejection_time = std::time::Duration::from_secs(3); - config - .keys - .push(russh_keys::key::KeyPair::generate_ed25519().unwrap()); - let config = Arc::new(config); - let socket = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = socket.local_addr().unwrap(); - - #[derive(Clone)] - struct Server {} - - let server_join = tokio::spawn(async move { - let (socket, _) = socket.accept().await.unwrap(); - - server::run_stream(config, socket, server_handler) - .await - .map_err(|_| ()) - .unwrap() - }); - - let client_join = tokio::spawn(async move { - let config = Arc::new(client::Config::default()); - let mut session = client::connect(config, addr, client_handler) - .await - .map_err(|_| ()) - .unwrap(); - let authenticated = session - .authenticate_publickey( - std::env::var("USER").unwrap_or("user".to_owned()), - Arc::new(client_key), - ) - .await - .unwrap(); - assert!(authenticated); - session - }); - - let (server_session, client_session) = tokio::join!(server_join, client_join); - let client_handle = tokio::spawn(run_client(client_session.unwrap())); - let server_handle = tokio::spawn(run_server(server_session.unwrap().handle())); - - let (server_session, client_session) = tokio::join!(server_handle, client_handle); - drop(client_session); - drop(server_session); -} - -#[cfg(test)] -mod test_channels { - use async_trait::async_trait; - use russh_cryptovec::CryptoVec; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - use crate::server::Session; - use crate::{client, server, test_session, Channel, ChannelId, ChannelMsg}; - - #[tokio::test] - async fn test_server_channels() { - #[derive(Debug)] - struct Client {} - - #[async_trait] - impl client::Handler for Client { - type Error = crate::Error; - - async fn check_server_key( - self, - _server_public_key: &russh_keys::key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - Ok((self, true)) - } - - async fn data( - self, - channel: ChannelId, - data: &[u8], - mut session: client::Session, - ) -> Result<(Self, client::Session), Self::Error> { - assert_eq!(data, &b"hello world!"[..]); - session.data(channel, CryptoVec::from_slice(&b"hey there!"[..])); - Ok((self, session)) - } - } - - struct ServerHandle { - did_auth: Option>, - } - - impl ServerHandle { - fn get_auth_waiter(&mut self) -> tokio::sync::oneshot::Receiver<()> { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.did_auth = Some(tx); - rx - } - } - - #[async_trait] - impl server::Handler for ServerHandle { - type Error = crate::Error; - - async fn auth_publickey( - self, - _: &str, - _: &russh_keys::key::PublicKey, - ) -> Result<(Self, server::Auth), Self::Error> { - Ok((self, server::Auth::Accept)) - } - async fn auth_succeeded( - mut self, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - if let Some(a) = self.did_auth.take() { - a.send(()).unwrap(); - } - Ok((self, session)) - } - } - - let mut sh = ServerHandle { did_auth: None }; - let a = sh.get_auth_waiter(); - test_session( - Client {}, - sh, - |c| async move { c }, - |s| async move { - a.await.unwrap(); - let mut ch = s.channel_open_session().await.unwrap(); - ch.data(&b"hello world!"[..]).await.unwrap(); - - let msg = ch.wait().await.unwrap(); - if let ChannelMsg::Data { data } = msg { - assert_eq!(data.as_ref(), &b"hey there!"[..]); - } else { - panic!("Unexpected message {:?}", msg); - } - s - }, - ) - .await; - } - - #[tokio::test] - async fn test_channel_streams() { - #[derive(Debug)] - struct Client {} - - #[async_trait] - impl client::Handler for Client { - type Error = crate::Error; - - async fn check_server_key( - self, - _server_public_key: &russh_keys::key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - Ok((self, true)) - } - } - - struct ServerHandle { - channel: Option>>, - } - - impl ServerHandle { - fn get_channel_waiter( - &mut self, - ) -> tokio::sync::oneshot::Receiver> { - let (tx, rx) = tokio::sync::oneshot::channel::>(); - self.channel = Some(tx); - rx - } - } - - #[async_trait] - impl server::Handler for ServerHandle { - type Error = crate::Error; - - async fn auth_publickey( - self, - _: &str, - _: &russh_keys::key::PublicKey, - ) -> Result<(Self, server::Auth), Self::Error> { - Ok((self, server::Auth::Accept)) - } - - async fn channel_open_session( - mut self, - channel: Channel, - session: server::Session, - ) -> Result<(Self, bool, Session), Self::Error> { - if let Some(a) = self.channel.take() { - println!("channel open session {:?}", a); - a.send(channel).unwrap(); - } - Ok((self, true, session)) - } - } - - let mut sh = ServerHandle { channel: None }; - let scw = sh.get_channel_waiter(); - - test_session( - Client {}, - sh, - |client| async move { - let ch = client.channel_open_session().await.unwrap(); - let mut stream = ch.into_stream(); - stream.write_all(&b"request"[..]).await.unwrap(); - - let mut buf = Vec::new(); - stream.read_buf(&mut buf).await.unwrap(); - assert_eq!(&buf, &b"response"[..]); - - stream.write_all(&b"reply"[..]).await.unwrap(); - - client - }, - |server| async move { - let channel = scw.await.unwrap(); - let mut stream = channel.into_stream(); - - let mut buf = Vec::new(); - stream.read_buf(&mut buf).await.unwrap(); - assert_eq!(&buf, &b"request"[..]); - - stream.write_all(&b"response"[..]).await.unwrap(); - - buf.clear(); - - stream.read_buf(&mut buf).await.unwrap(); - assert_eq!(&buf, &b"reply"[..]); - - server - }, - ) - .await; - } - - #[tokio::test] - async fn test_channel_objects() { - #[derive(Debug)] - struct Client {} - - #[async_trait] - impl client::Handler for Client { - type Error = crate::Error; - - async fn check_server_key( - self, - _server_public_key: &russh_keys::key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - Ok((self, true)) - } - } - - struct ServerHandle {} - - impl ServerHandle {} - - #[async_trait] - impl server::Handler for ServerHandle { - type Error = crate::Error; - - async fn auth_publickey( - self, - _: &str, - _: &russh_keys::key::PublicKey, - ) -> Result<(Self, server::Auth), Self::Error> { - Ok((self, server::Auth::Accept)) - } - - async fn channel_open_session( - self, - mut channel: Channel, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - tokio::spawn(async move { - while let Some(msg) = channel.wait().await { - match msg { - ChannelMsg::Data { data } => { - channel.data(&data[..]).await.unwrap(); - channel.close().await.unwrap(); - break - } - _ => {} - } - } - }); - Ok((self, true, session)) - } - } - - let sh = ServerHandle {}; - test_session( - Client {}, - sh, - |c| async move { - let mut ch = c.channel_open_session().await.unwrap(); - ch.data(&b"hello world!"[..]).await.unwrap(); - - let msg = ch.wait().await.unwrap(); - if let ChannelMsg::Data { data } = msg { - assert_eq!(data.as_ref(), &b"hey there!"[..]); - } else { - panic!("Unexpected message {:?}", msg); - } - - let msg = ch.wait().await.unwrap(); - let ChannelMsg::Close = msg else { - panic!("Unexpected message {:?}", msg); - }; - - ch.close().await.unwrap(); - c - }, - |s| async move { s }, - ) - .await; - } +pub(crate) fn future_or_pending( + val: Option, + f: impl FnOnce(T) -> F, +) -> futures::future::Either::Output>, F> { + val.map_or( + futures::future::Either::Left(futures::future::pending()), + |x| futures::future::Either::Right(f(x)), + ) } diff --git a/russh/src/mac/mod.rs b/russh/src/mac/mod.rs index 5eada31b..088f50be 100644 --- a/russh/src/mac/mod.rs +++ b/russh/src/mac/mod.rs @@ -14,6 +14,7 @@ //! //! This module exports cipher names for use with [Preferred]. use std::collections::HashMap; +use std::convert::TryFrom; use std::marker::PhantomData; use digest::typenum::{U20, U32, U64}; @@ -52,6 +53,13 @@ impl AsRef for Name { } } +impl TryFrom<&str> for Name { + type Error = (); + fn try_from(s: &str) -> Result { + MACS.keys().find(|x| x.0 == s).map(|x| **x).ok_or(()) + } +} + /// `none` pub const NONE: Name = Name("none"); /// `hmac-sha1` @@ -74,13 +82,23 @@ static _HMAC_SHA256: CryptoMacAlgorithm, U32> = CryptoMacAlgorithm(PhantomData, PhantomData); static _HMAC_SHA512: CryptoMacAlgorithm, U64> = CryptoMacAlgorithm(PhantomData, PhantomData); -static _HMAC_SHA1_ETM: CryptoEtmMacAlgorithm, U64> = +static _HMAC_SHA1_ETM: CryptoEtmMacAlgorithm, U20> = CryptoEtmMacAlgorithm(PhantomData, PhantomData); -static _HMAC_SHA256_ETM: CryptoEtmMacAlgorithm, U64> = +static _HMAC_SHA256_ETM: CryptoEtmMacAlgorithm, U32> = CryptoEtmMacAlgorithm(PhantomData, PhantomData); static _HMAC_SHA512_ETM: CryptoEtmMacAlgorithm, U64> = CryptoEtmMacAlgorithm(PhantomData, PhantomData); +pub const ALL_MAC_ALGORITHMS: &[&Name] = &[ + &NONE, + &HMAC_SHA1, + &HMAC_SHA256, + &HMAC_SHA512, + &HMAC_SHA1_ETM, + &HMAC_SHA256_ETM, + &HMAC_SHA512_ETM, +]; + pub(crate) static MACS: Lazy> = Lazy::new(|| { let mut h: HashMap<&'static Name, &(dyn MacAlgorithm + Send + Sync)> = HashMap::new(); @@ -91,5 +109,6 @@ pub(crate) static MACS: Lazy, + /// Preferred host & public key algorithms. + pub key: Cow<'static, [key::Name]>, /// Preferred symmetric ciphers. - pub cipher: &'static [cipher::Name], + pub cipher: Cow<'static, [cipher::Name]>, /// Preferred MAC algorithms. - pub mac: &'static [mac::Name], + pub mac: Cow<'static, [mac::Name]>, /// Preferred compression algorithms. - pub compression: &'static [&'static str], + pub compression: Cow<'static, [compression::Name]>, +} + +impl Preferred { + pub(crate) fn possible_host_key_algos_for_keys( + &self, + available_host_keys: &[KeyPair], + ) -> Vec { + self.key + .iter() + .filter(|n| available_host_keys.iter().any(|k| k.name() == n.0)) + .copied() + .collect::>() + } } const SAFE_KEX_ORDER: &[kex::Name] = &[ kex::CURVE25519, kex::CURVE25519_PRE_RFC_8731, + kex::DH_G16_SHA512, kex::DH_G14_SHA256, + kex::EXTENSION_SUPPORT_AS_CLIENT, + kex::EXTENSION_SUPPORT_AS_SERVER, + kex::EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, + kex::EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER, ]; const CIPHER_ORDER: &[cipher::Name] = &[ @@ -75,31 +95,35 @@ const HMAC_ORDER: &[mac::Name] = &[ mac::HMAC_SHA1, ]; -impl Preferred { - #[cfg(feature = "openssl")] - pub const DEFAULT: Preferred = Preferred { - kex: SAFE_KEX_ORDER, - key: &[key::ED25519, key::RSA_SHA2_256, key::RSA_SHA2_512], - cipher: CIPHER_ORDER, - mac: HMAC_ORDER, - compression: &["none", "zlib", "zlib@openssh.com"], - }; +const COMPRESSION_ORDER: &[compression::Name] = &[ + compression::NONE, + #[cfg(feature = "flate2")] + compression::ZLIB, + #[cfg(feature = "flate2")] + compression::ZLIB_LEGACY, +]; - #[cfg(not(feature = "openssl"))] +impl Preferred { pub const DEFAULT: Preferred = Preferred { - kex: SAFE_KEX_ORDER, - key: &[key::ED25519], - cipher: CIPHER_ORDER, - mac: HMAC_ORDER, - compression: &["none", "zlib", "zlib@openssh.com"], + kex: Cow::Borrowed(SAFE_KEX_ORDER), + key: Cow::Borrowed(&[ + key::ED25519, + key::ECDSA_SHA2_NISTP256, + key::ECDSA_SHA2_NISTP521, + key::RSA_SHA2_256, + key::RSA_SHA2_512, + ]), + cipher: Cow::Borrowed(CIPHER_ORDER), + mac: Cow::Borrowed(HMAC_ORDER), + compression: Cow::Borrowed(COMPRESSION_ORDER), }; pub const COMPRESSED: Preferred = Preferred { - kex: SAFE_KEX_ORDER, - key: &[key::ED25519, key::RSA_SHA2_256, key::RSA_SHA2_512], - cipher: CIPHER_ORDER, - mac: HMAC_ORDER, - compression: &["zlib", "zlib@openssh.com", "none"], + kex: Cow::Borrowed(SAFE_KEX_ORDER), + key: Preferred::DEFAULT.key, + cipher: Cow::Borrowed(CIPHER_ORDER), + mac: Cow::Borrowed(HMAC_ORDER), + compression: Cow::Borrowed(COMPRESSION_ORDER), }; } @@ -121,17 +145,14 @@ impl Named for () { } } -#[cfg(not(feature = "openssl"))] -use russh_keys::key::ED25519; -#[cfg(feature = "openssl")] -use russh_keys::key::{ED25519, SSH_RSA}; +use crate::keys::key::ED25519; impl Named for PublicKey { fn name(&self) -> &'static str { match self { PublicKey::Ed25519(_) => ED25519.0, - #[cfg(feature = "openssl")] - PublicKey::RSA { .. } => SSH_RSA.0, + PublicKey::RSA { ref hash, .. } => hash.name().0, + PublicKey::EC { ref key } => key.algorithm(), } } } @@ -140,111 +161,163 @@ impl Named for KeyPair { fn name(&self) -> &'static str { match self { KeyPair::Ed25519 { .. } => ED25519.0, - #[cfg(feature = "openssl")] KeyPair::RSA { ref hash, .. } => hash.name().0, + KeyPair::EC { ref key } => key.algorithm(), } } } -pub trait Select { - fn select + Copy>(a: &[S], b: &[u8]) -> Option<(bool, S)>; +pub(crate) fn parse_kex_algo_list(list: &[u8]) -> Vec<&str> { + list.split(|&x| x == b',') + .map(|x| from_utf8(x).unwrap_or_default()) + .collect() +} - fn read_kex(buffer: &[u8], pref: &Preferred) -> Result { +pub(crate) trait Select { + fn is_server() -> bool; + + fn select + Clone>( + a: &[S], + b: &[&str], + kind: AlgorithmKind, + ) -> Result<(bool, S), Error>; + + /// `available_host_keys`, if present, is used to limit the host key algorithms to the ones we have keys for. + fn read_kex( + buffer: &[u8], + pref: &Preferred, + available_host_keys: Option<&[KeyPair]>, + ) -> Result { let mut r = buffer.reader(17); + + // Key exchange + let kex_string = r.read_string()?; - let (kex_both_first, kex_algorithm) = if let Some(x) = Self::select(pref.kex, kex_string) { - x - } else { - debug!( - "Could not find common kex algorithm, other side only supports {:?}, we only support {:?}", - from_utf8(kex_string), - pref.kex - ); - return Err(Error::NoCommonKexAlgo); - }; + let (kex_both_first, kex_algorithm) = Self::select( + &pref.kex, + &parse_kex_algo_list(kex_string), + AlgorithmKind::Kex, + )?; - let key_string = r.read_string()?; - let (key_both_first, key_algorithm) = if let Some(x) = Self::select(pref.key, key_string) { - x + // Strict kex detection + + let strict_kex_requested = pref.kex.contains(if Self::is_server() { + &EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER } else { - debug!( - "Could not find common key algorithm, other side only supports {:?}, we only support {:?}", - from_utf8(key_string), - pref.key - ); - return Err(Error::NoCommonKeyAlgo); + &EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT + }); + let strict_kex_provided = Self::select( + &[if Self::is_server() { + EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT + } else { + EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER + }], + &parse_kex_algo_list(kex_string), + AlgorithmKind::Kex, + ) + .is_ok(); + if strict_kex_requested && strict_kex_provided { + debug!("strict kex enabled") + } + + // Host key + + let key_string: &[u8] = r.read_string()?; + let possible_host_key_algos = match available_host_keys { + Some(available_host_keys) => pref.possible_host_key_algos_for_keys(available_host_keys), + None => pref.key.iter().map(ToOwned::to_owned).collect::>(), }; + let (key_both_first, key_algorithm) = Self::select( + &possible_host_key_algos[..], + &parse_kex_algo_list(key_string), + AlgorithmKind::Key, + )?; + + // Cipher + let cipher_string = r.read_string()?; - let cipher = Self::select(pref.cipher, cipher_string); - if cipher.is_none() { - debug!( - "Could not find common cipher, other side only supports {:?}, we only support {:?}", - from_utf8(cipher_string), - pref.cipher - ); - return Err(Error::NoCommonCipher); - } + let (_cipher_both_first, cipher) = Self::select( + &pref.cipher, + &parse_kex_algo_list(cipher_string), + AlgorithmKind::Cipher, + )?; r.read_string()?; // cipher server-to-client. debug!("kex {}", line!()); - let need_mac = cipher - .and_then(|x| CIPHERS.get(&x.1)) - .map(|x| x.needs_mac()) - .unwrap_or(false); - - let client_mac = if let Some((_, m)) = Self::select(pref.mac, r.read_string()?) { - m - } else if need_mac { - return Err(Error::NoCommonMac); - } else { - mac::NONE + // MAC + + let need_mac = CIPHERS.get(&cipher).map(|x| x.needs_mac()).unwrap_or(false); + + let client_mac = match Self::select( + &pref.mac, + &parse_kex_algo_list(r.read_string()?), + AlgorithmKind::Mac, + ) { + Ok((_, m)) => m, + Err(e) => { + if need_mac { + return Err(e); + } else { + mac::NONE + } + } }; - let server_mac = if let Some((_, m)) = Self::select(pref.mac, r.read_string()?) { - m - } else if need_mac { - return Err(Error::NoCommonMac); - } else { - mac::NONE + let server_mac = match Self::select( + &pref.mac, + &parse_kex_algo_list(r.read_string()?), + AlgorithmKind::Mac, + ) { + Ok((_, m)) => m, + Err(e) => { + if need_mac { + return Err(e); + } else { + mac::NONE + } + } }; + // Compression + debug!("kex {}", line!()); // client-to-server compression. - let client_compression = - if let Some((_, c)) = Self::select(pref.compression, r.read_string()?) { - Compression::from_string(c) - } else { - return Err(Error::NoCommonCompression); - }; + let client_compression = compression::Compression::new( + &Self::select( + &pref.compression, + &parse_kex_algo_list(r.read_string()?), + AlgorithmKind::Compression, + )? + .1, + ); + debug!("kex {}", line!()); // server-to-client compression. - let server_compression = - if let Some((_, c)) = Self::select(pref.compression, r.read_string()?) { - Compression::from_string(c) - } else { - return Err(Error::NoCommonCompression); - }; + let server_compression = compression::Compression::new( + &Self::select( + &pref.compression, + &parse_kex_algo_list(r.read_string()?), + AlgorithmKind::Compression, + )? + .1, + ); debug!("client_compression = {:?}", client_compression); r.read_string()?; // languages client-to-server r.read_string()?; // languages server-to-client let follows = r.read_byte()? != 0; - match (cipher, follows) { - (Some((_, cipher)), fol) => { - Ok(Names { - kex: kex_algorithm, - key: key_algorithm, - cipher, - client_mac, - server_mac, - client_compression, - server_compression, - // Ignore the next packet if (1) it follows and (2) it's not the correct guess. - ignore_guessed: fol && !(kex_both_first && key_both_first), - }) - } - _ => Err(Error::KexInit), - } + Ok(Names { + kex: kex_algorithm, + key: key_algorithm, + cipher, + client_mac, + server_mac, + client_compression, + server_compression, + // Ignore the next packet if (1) it follows and (2) it's not the correct guess. + ignore_guessed: follows && !(kex_both_first && key_both_first), + strict_kex: strict_kex_requested && strict_kex_provided, + }) } } @@ -252,36 +325,64 @@ pub struct Server; pub struct Client; impl Select for Server { - fn select + Copy>(server_list: &[S], client_list: &[u8]) -> Option<(bool, S)> { + fn is_server() -> bool { + true + } + + fn select + Clone>( + server_list: &[S], + client_list: &[&str], + kind: AlgorithmKind, + ) -> Result<(bool, S), Error> { let mut both_first_choice = true; - for c in client_list.split(|&x| x == b',') { - for &s in server_list { - if c == s.as_ref().as_bytes() { - return Some((both_first_choice, s)); + for c in client_list { + for s in server_list { + if c == &s.as_ref() { + return Ok((both_first_choice, s.clone())); } both_first_choice = false } } - None + Err(Error::NoCommonAlgo { + kind, + ours: server_list.iter().map(|x| x.as_ref().to_owned()).collect(), + theirs: client_list.iter().map(|x| (*x).to_owned()).collect(), + }) } } impl Select for Client { - fn select + Copy>(client_list: &[S], server_list: &[u8]) -> Option<(bool, S)> { + fn is_server() -> bool { + false + } + + fn select + Clone>( + client_list: &[S], + server_list: &[&str], + kind: AlgorithmKind, + ) -> Result<(bool, S), Error> { let mut both_first_choice = true; - for &c in client_list { - for s in server_list.split(|&x| x == b',') { - if s == c.as_ref().as_bytes() { - return Some((both_first_choice, c)); + for c in client_list { + for s in server_list { + if s == &c.as_ref() { + return Ok((both_first_choice, c.clone())); } both_first_choice = false } } - None + Err(Error::NoCommonAlgo { + kind, + ours: client_list.iter().map(|x| x.as_ref().to_owned()).collect(), + theirs: server_list.iter().map(|x| (*x).to_owned()).collect(), + }) } } -pub fn write_kex(prefs: &Preferred, buf: &mut CryptoVec, as_server: bool) -> Result<(), Error> { +pub fn write_kex( + prefs: &Preferred, + buf: &mut CryptoVec, + server_config: Option<&Config>, +) -> Result<(), Error> { // buf.clear(); buf.push(msg::KEXINIT); @@ -290,14 +391,31 @@ pub fn write_kex(prefs: &Preferred, buf: &mut CryptoVec, as_server: bool) -> Res buf.extend(&cookie); // cookie buf.extend_list(prefs.kex.iter().filter(|k| { - **k != if as_server { - crate::kex::EXTENSION_SUPPORT_AS_CLIENT + !(if server_config.is_some() { + [ + crate::kex::EXTENSION_SUPPORT_AS_CLIENT, + crate::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, + ] } else { - crate::kex::EXTENSION_SUPPORT_AS_SERVER - } + [ + crate::kex::EXTENSION_SUPPORT_AS_SERVER, + crate::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER, + ] + }) + .contains(*k) })); // kex algo - buf.extend_list(prefs.key.iter()); + if let Some(server_config) = server_config { + // Only advertise host key algorithms that we have keys for. + buf.extend_list( + prefs + .key + .iter() + .filter(|name| server_config.keys.iter().any(|k| k.name() == name.0)), + ); + } else { + buf.extend_list(prefs.key.iter()); + } buf.extend_list(prefs.cipher.iter()); // cipher client to server buf.extend_list(prefs.cipher.iter()); // cipher server to client diff --git a/russh/src/parsing.rs b/russh/src/parsing.rs index 77f84e07..fe80c974 100644 --- a/russh/src/parsing.rs +++ b/russh/src/parsing.rs @@ -1,7 +1,5 @@ -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::{Encoding, Position}; - -use crate::msg; +use crate::keys::encoding::{Encoding, Position}; +use crate::{msg, CryptoVec}; #[derive(Debug)] pub struct OpenChannelMessage { @@ -34,6 +32,9 @@ impl OpenChannelMessage { } b"direct-tcpip" => ChannelType::DirectTcpip(TcpChannelInfo::new(r)?), b"forwarded-tcpip" => ChannelType::ForwardedTcpIp(TcpChannelInfo::new(r)?), + b"forwarded-streamlocal@openssh.com" => { + ChannelType::ForwardedStreamLocal(StreamLocalChannelInfo::new(r)?) + } b"auth-agent@openssh.com" => ChannelType::AgentForward, t => ChannelType::Unknown { typ: t.to_vec() }, }; @@ -93,6 +94,7 @@ pub enum ChannelType { }, DirectTcpip(TcpChannelInfo), ForwardedTcpIp(TcpChannelInfo), + ForwardedStreamLocal(StreamLocalChannelInfo), AgentForward, Unknown { typ: Vec, @@ -107,6 +109,21 @@ pub struct TcpChannelInfo { pub originator_port: u32, } +#[derive(Debug)] +pub struct StreamLocalChannelInfo { + pub socket_path: String, +} + +impl StreamLocalChannelInfo { + fn new(r: &mut Position) -> Result { + let socket_path = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) + .map_err(crate::Error::from)? + .to_owned(); + + Ok(Self { socket_path }) + } +} + impl TcpChannelInfo { fn new(r: &mut Position) -> Result { let host_to_connect = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index abfa2555..9f790610 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -18,25 +18,25 @@ use auth::*; use byteorder::{BigEndian, ByteOrder}; use log::{debug, error, info, trace, warn}; use negotiation::Select; -use russh_keys::encoding::{Encoding, Position, Reader}; -use russh_keys::key; -use russh_keys::key::Verify; -use tokio::sync::mpsc::unbounded_channel; use tokio::time::Instant; use {msg, negotiation}; use super::super::*; use super::*; +use crate::keys::encoding::{Encoding, Position, Reader}; +use crate::keys::key; +use crate::keys::key::Verify; use crate::msg::SSH_OPEN_ADMINISTRATIVELY_PROHIBITED; use crate::parsing::{ChannelOpenConfirmation, ChannelType, OpenChannelMessage}; impl Session { /// Returns false iff a request was rejected. pub(crate) async fn server_read_encrypted( - mut self, - mut handler: H, + &mut self, + handler: &mut H, + seqn: &mut Wrapping, buf: &[u8], - ) -> Result<(H, Self), H::Error> { + ) -> Result<(), H::Error> { #[allow(clippy::indexing_slicing)] // length checked { trace!( @@ -47,7 +47,7 @@ impl Session { // Either this packet is a KEXINIT, in which case we start a key re-exchange. #[allow(clippy::unwrap_used)] - let mut enc = self.common.encrypted.as_mut().unwrap(); + let enc = self.common.encrypted.as_mut().unwrap(); if buf.first() == Some(&msg::KEXINIT) { debug!("Received rekeying request"); // If we're not currently rekeying, but `buf` is a rekey request @@ -61,7 +61,11 @@ impl Session { } else if let Some(exchange) = enc.exchange.take() { let kexinit = KexInit::received_rekey( exchange, - negotiation::Server::read_kex(buf, &self.common.config.as_ref().preferred)?, + negotiation::Server::read_kex( + buf, + &self.common.config.as_ref().preferred, + Some(&self.common.config.as_ref().keys), + )?, &enc.session_id, ); enc.rekey = Some(kexinit.server_parse( @@ -71,8 +75,11 @@ impl Session { &mut self.common.write_buffer, )?); } + if let Some(Kex::Dh(KexDh { ref names, .. })) = enc.rekey { + self.common.strict_kex = self.common.strict_kex || names.strict_kex; + } self.flush()?; - return Ok((handler, self)); + return Ok(()); } match enc.rekey.take() { @@ -83,8 +90,12 @@ impl Session { buf, &mut self.common.write_buffer, )?); + if let Some(Kex::Keys(_)) = enc.rekey { + // just sent NEWKEYS + self.common.maybe_reset_seqn(); + } self.flush()?; - return Ok((handler, self)); + return Ok(()); } Some(Kex::Keys(newkeys)) => { if buf.first() != Some(&msg::NEWKEYS) { @@ -97,24 +108,32 @@ impl Session { enc.flush_all_pending(); let mut pending = std::mem::take(&mut self.pending_reads); for p in pending.drain(..) { - let (h, s) = self.process_packet(handler, &p).await?; - handler = h; - self = s; + self.process_packet(handler, &p).await?; } self.pending_reads = pending; self.pending_len = 0; self.common.newkeys(newkeys); + if self.common.strict_kex { + *seqn = Wrapping(0); + } self.flush()?; - return Ok((handler, self)); + return Ok(()); } Some(Kex::Init(k)) => { + if let Some(ref algo) = k.algo { + if self.common.strict_kex && !algo.strict_kex { + return Err(strict_kex_violation(msg::KEXINIT, 0).into()); + } + } + enc.rekey = Some(Kex::Init(k)); + self.pending_len += buf.len() as u32; if self.pending_len > 2 * self.target_window_size { return Err(Error::Pending.into()); } self.pending_reads.push(CryptoVec::from_slice(buf)); - return Ok((handler, self)); + return Ok(()); } rek => { trace!("rek = {:?}", rek); @@ -125,10 +144,10 @@ impl Session { } async fn process_packet( - mut self, - mut handler: H, + &mut self, + handler: &mut H, buf: &[u8], - ) -> Result<(H, Self), H::Error> { + ) -> Result<(), H::Error> { let rejection_wait_until = tokio::time::Instant::now() + self.common.config.auth_rejection_time; let initial_none_rejection_wait_until = if self.common.auth_attempts == 0 { @@ -143,7 +162,7 @@ impl Session { }; #[allow(clippy::unwrap_used)] - let mut enc = self.common.encrypted.as_mut().unwrap(); + let enc = self.common.encrypted.as_mut().unwrap(); // If we've successfully read a packet. match enc.state { EncryptedState::WaitingAuthServiceRequest { @@ -161,47 +180,44 @@ impl Session { *accepted = true; enc.state = EncryptedState::WaitingAuthRequest(auth_request); } - Ok((handler, self)) + Ok(()) } EncryptedState::WaitingAuthRequest(_) if buf.first() == Some(&msg::USERAUTH_REQUEST) => { - handler = enc - .server_read_auth_request( - rejection_wait_until, - initial_none_rejection_wait_until, - handler, - buf, - &mut self.common.auth_user, - ) - .await?; + enc.server_read_auth_request( + rejection_wait_until, + initial_none_rejection_wait_until, + handler, + buf, + &mut self.common.auth_user, + ) + .await?; self.common.auth_attempts += 1; if let EncryptedState::InitCompression = enc.state { enc.client_compression.init_decompress(&mut enc.decompress); - handler.auth_succeeded(self).await - } else { - Ok((handler, self)) + handler.auth_succeeded(self).await?; } + Ok(()) } EncryptedState::WaitingAuthRequest(ref mut auth) if buf.first() == Some(&msg::USERAUTH_INFO_RESPONSE) => { - let (h, resp) = read_userauth_info_response( + let resp = read_userauth_info_response( rejection_wait_until, handler, &mut enc.write, auth, - &mut self.common.auth_user, + &self.common.auth_user, buf, ) .await?; - handler = h; if resp { enc.state = EncryptedState::InitCompression; enc.client_compression.init_decompress(&mut enc.decompress); handler.auth_succeeded(self).await } else { - Ok((handler, self)) + Ok(()) } } EncryptedState::InitCompression => { @@ -210,7 +226,7 @@ impl Session { self.server_read_authenticated(handler, buf).await } EncryptedState::Authenticated => self.server_read_authenticated(handler, buf).await, - _ => Ok((handler, self)), + _ => Ok(()), } } } @@ -247,10 +263,10 @@ impl Encrypted { &mut self, mut until: Instant, initial_auth_until: Instant, - mut handler: H, + handler: &mut H, buf: &[u8], auth_user: &mut String, - ) -> Result { + ) -> Result<(), H::Error> { // https://tools.ietf.org/html/rfc4252#section-5 let mut r = buf.reader(1); let user = r.read_string().map_err(crate::Error::from)?; @@ -277,17 +293,24 @@ impl Encrypted { r.read_byte().map_err(crate::Error::from)?; let password = r.read_string().map_err(crate::Error::from)?; let password = std::str::from_utf8(password).map_err(crate::Error::from)?; - let (handler, auth) = handler.auth_password(user, password).await?; + let auth = handler.auth_password(user, password).await?; if let Auth::Accept = auth { server_auth_request_success(&mut self.write); self.state = EncryptedState::InitCompression; } else { auth_user.clear(); - auth_request.methods -= MethodSet::PASSWORD; + if let Auth::Reject { + proceed_with_methods: Some(proceed_with_methods), + } = auth + { + auth_request.methods = proceed_with_methods; + } else { + auth_request.methods -= MethodSet::PASSWORD; + } auth_request.partial_success = false; reject_auth_request(until, &mut self.write, auth_request).await; } - Ok(handler) + Ok(()) } else if method == b"publickey" { self.server_read_auth_request_pk(until, handler, buf, auth_user, user, r) .await @@ -303,17 +326,24 @@ impl Encrypted { until = initial_auth_until } - let (handler, auth) = handler.auth_none(user).await?; + let auth = handler.auth_none(user).await?; if let Auth::Accept = auth { server_auth_request_success(&mut self.write); self.state = EncryptedState::InitCompression; } else { auth_user.clear(); - auth_request.methods -= MethodSet::NONE; + if let Auth::Reject { + proceed_with_methods: Some(proceed_with_methods), + } = auth + { + auth_request.methods = proceed_with_methods; + } else { + auth_request.methods -= MethodSet::NONE; + } auth_request.partial_success = false; reject_auth_request(until, &mut self.write, auth_request).await; } - Ok(handler) + Ok(()) } else if method == b"keyboard-interactive" { let auth_request = if let EncryptedState::WaitingAuthRequest(ref mut a) = self.state { @@ -330,14 +360,13 @@ impl Encrypted { auth_request.current = Some(CurrentRequest::KeyboardInteractive { submethods: submethods.to_string(), }); - let (h, auth) = handler + let auth = handler .auth_keyboard_interactive(user, submethods, None) .await?; - handler = h; if reply_userauth_info_response(until, auth_request, &mut self.write, auth).await? { self.state = EncryptedState::InitCompression } - Ok(handler) + Ok(()) } else { // Other methods of the base specification are insecure or optional. let auth_request = if let EncryptedState::WaitingAuthRequest(ref mut a) = self.state @@ -347,7 +376,7 @@ impl Encrypted { unreachable!() }; reject_auth_request(until, &mut self.write, auth_request).await; - Ok(handler) + Ok(()) } } else { // Unknown service @@ -364,12 +393,12 @@ impl Encrypted { async fn server_read_auth_request_pk( &mut self, until: Instant, - mut handler: H, + handler: &mut H, buf: &[u8], auth_user: &mut String, user: &str, mut r: Position<'_>, - ) -> Result { + ) -> Result<(), H::Error> { let auth_request = if let EncryptedState::WaitingAuthRequest(ref mut a) = self.state { a } else { @@ -397,7 +426,9 @@ impl Encrypted { debug!("signature = {:?}", signature); let mut s = signature.reader(0); let algo_ = s.read_string().map_err(crate::Error::from)?; - pubkey.set_algorithm(algo_); + if let Some(hash) = key::SignatureHash::from_rsa_hostkey_algo(algo_) { + pubkey.set_algorithm(hash); + } debug!("algo_: {:?}", algo_); let sig = s.read_string().map_err(crate::Error::from)?; #[allow(clippy::indexing_slicing)] // length checked @@ -408,15 +439,14 @@ impl Encrypted { } else if auth_user.is_empty() { auth_user.clear(); auth_user.push_str(user); - let (h, auth) = handler.auth_publickey(user, &pubkey).await?; - handler = h; + let auth = handler.auth_publickey_offered(user, &pubkey).await?; auth == Auth::Accept } else { false }; if is_valid { let session_id = self.session_id.as_ref(); - #[allow(clippy::blocks_in_if_conditions)] // length checked + #[allow(clippy::blocks_in_conditions)] if SIGNATURE_BUFFER.with(|buf| { let mut buf = buf.borrow_mut(); buf.clear(); @@ -426,8 +456,22 @@ impl Encrypted { pubkey.verify_client_auth(&buf, sig) }) { debug!("signature verified"); - server_auth_request_success(&mut self.write); - self.state = EncryptedState::InitCompression; + let auth = handler.auth_publickey(user, &pubkey).await?; + + if auth == Auth::Accept { + server_auth_request_success(&mut self.write); + self.state = EncryptedState::InitCompression; + } else { + if let Auth::Reject { + proceed_with_methods: Some(proceed_with_methods), + } = auth + { + auth_request.methods = proceed_with_methods; + } + auth_request.partial_success = false; + auth_user.clear(); + reject_auth_request(until, &mut self.write, auth_request).await; + } } else { debug!("signature wrong"); reject_auth_request(until, &mut self.write, auth_request).await; @@ -435,12 +479,11 @@ impl Encrypted { } else { reject_auth_request(until, &mut self.write, auth_request).await; } - Ok(handler) + Ok(()) } else { auth_user.clear(); auth_user.push_str(user); - let (h, auth) = handler.auth_publickey(user, &pubkey).await?; - handler = h; + let auth = handler.auth_publickey_offered(user, &pubkey).await?; match auth { Auth::Accept => { let mut public_key = CryptoVec::new(); @@ -473,17 +516,14 @@ impl Encrypted { reject_auth_request(until, &mut self.write, auth_request).await; } } - Ok(handler) + Ok(()) } } - Err(e) => { - if let russh_keys::Error::CouldNotReadKey = e { - reject_auth_request(until, &mut self.write, auth_request).await; - Ok(handler) - } else { - Err(crate::Error::from(e).into()) - } + Err(russh_keys::Error::CouldNotReadKey) | Err(russh_keys::Error::KeyIsCorrupt) => { + reject_auth_request(until, &mut self.write, auth_request).await; + Ok(()) } + Err(e) => Err(crate::Error::from(e).into()), } } } @@ -513,27 +553,26 @@ fn server_auth_request_success(buffer: &mut CryptoVec) { async fn read_userauth_info_response( until: Instant, - mut handler: H, + handler: &mut H, write: &mut CryptoVec, auth_request: &mut AuthRequest, - user: &mut str, + user: &str, b: &[u8], -) -> Result<(H, bool), H::Error> { +) -> Result { if let Some(CurrentRequest::KeyboardInteractive { ref submethods }) = auth_request.current { let mut r = b.reader(1); let n = r.read_u32().map_err(crate::Error::from)?; let response = Response { pos: r, n }; - let (h, auth) = handler + let auth = handler .auth_keyboard_interactive(user, submethods, Some(response)) .await?; - handler = h; let resp = reply_userauth_info_response(until, auth_request, write, auth) .await .map_err(H::Error::from)?; - Ok((handler, resp)) + Ok(resp) } else { reject_auth_request(until, write, auth_request).await; - Ok((handler, false)) + Ok(false) } } @@ -582,10 +621,10 @@ async fn reply_userauth_info_response( impl Session { async fn server_read_authenticated( - mut self, - mut handler: H, + &mut self, + handler: &mut H, buf: &[u8], - ) -> Result<(H, Self), H::Error> { + ) -> Result<(), H::Error> { #[allow(clippy::indexing_slicing)] // length checked { trace!( @@ -597,7 +636,7 @@ impl Session { Some(&msg::CHANNEL_OPEN) => self .server_handle_channel_open(handler, buf) .await - .map(|(h, _, s)| (h, s)), + .map(|_| ()), Some(&msg::CHANNEL_CLOSE) => { let mut r = buf.reader(1); let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); @@ -676,6 +715,8 @@ impl Session { enc.flush_pending(channel_num); } if let Some(chan) = self.channels.get(&channel_num) { + *chan.window_size().lock().await = new_size; + chan.send(ChannelMsg::WindowAdjusted { new_size }) .unwrap_or(()) } @@ -858,15 +899,14 @@ impl Session { let _ = chan.send(ChannelMsg::AgentForward { want_reply: true }); } debug!("handler.agent_request {:?}", channel_num); - let response; - (handler, response, self) = - handler.agent_request(channel_num, self).await?; + + let response = handler.agent_request(channel_num, self).await?; if response { self.request_success() } else { self.request_failure() } - Ok((handler, self)) + Ok(()) } b"exec" => { let req = r.read_string().map_err(crate::Error::from)?; @@ -934,7 +974,7 @@ impl Session { x => { warn!("unknown channel request {}", String::from_utf8_lossy(x)); self.channel_failure(channel_num); - Ok((handler, self)) + Ok(()) } } } @@ -950,14 +990,14 @@ impl Session { let port = r.read_u32().map_err(crate::Error::from)?; debug!("handler.tcpip_forward {:?} {:?}", address, port); let mut returned_port = port; - let (h, result, mut s) = handler + let result = handler .tcpip_forward(address, &mut returned_port, self) .await?; - if let Some(ref mut enc) = s.common.encrypted { + if let Some(ref mut enc) = self.common.encrypted { if result { push_packet!(enc.write, { enc.write.push(msg::REQUEST_SUCCESS); - if s.common.wants_reply && port == 0 && returned_port != 0 { + if self.common.wants_reply && port == 0 && returned_port != 0 { enc.write.push_u32_be(returned_port); } }) @@ -965,7 +1005,7 @@ impl Session { push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE)) } } - Ok((h, s)) + Ok(()) } b"cancel-tcpip-forward" => { let address = @@ -973,16 +1013,49 @@ impl Session { .map_err(crate::Error::from)?; let port = r.read_u32().map_err(crate::Error::from)?; debug!("handler.cancel_tcpip_forward {:?} {:?}", address, port); - let (h, result, mut s) = - handler.cancel_tcpip_forward(address, port, self).await?; - if let Some(ref mut enc) = s.common.encrypted { + let result = handler.cancel_tcpip_forward(address, port, self).await?; + if let Some(ref mut enc) = self.common.encrypted { if result { push_packet!(enc.write, enc.write.push(msg::REQUEST_SUCCESS)) } else { push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE)) } } - Ok((h, s)) + Ok(()) + } + b"streamlocal-forward@openssh.com" => { + let server_socket_path = + std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) + .map_err(crate::Error::from)?; + debug!("handler.streamlocal_forward {:?}", server_socket_path); + let result = handler + .streamlocal_forward(server_socket_path, self) + .await?; + if let Some(ref mut enc) = self.common.encrypted { + if result { + push_packet!(enc.write, enc.write.push(msg::REQUEST_SUCCESS)) + } else { + push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE)) + } + } + Ok(()) + } + b"cancel-streamlocal-forward@openssh.com" => { + let socket_path = + std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) + .map_err(crate::Error::from)?; + debug!("handler.cancel_streamlocal_forward {:?}", socket_path); + let result = handler + .cancel_streamlocal_forward(socket_path, self) + .await?; + if let Some(ref mut enc) = self.common.encrypted { + if result { + push_packet!(enc.write, enc.write.push(msg::REQUEST_SUCCESS)) + } else { + push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE)) + } + } + Ok(()) } _ => { if let Some(ref mut enc) = self.common.encrypted { @@ -990,7 +1063,7 @@ impl Session { enc.write.push(msg::REQUEST_FAILURE); }); } - Ok((handler, self)) + Ok(()) } } } @@ -1021,20 +1094,69 @@ impl Session { .map_err(|_| crate::Error::SendError)?; } - Ok((handler, self)) + Ok(()) + } + Some(&msg::REQUEST_SUCCESS) => { + trace!("Global Request Success"); + match self.open_global_requests.pop_front() { + Some(GlobalRequestResponse::Keepalive) => { + // ignore keepalives + } + Some(GlobalRequestResponse::TcpIpForward(return_channel)) => { + let result = if buf.len() == 1 { + // If a specific port was requested, the reply has no data + Some(0) + } else { + let mut r = buf.reader(1); + match r.read_u32() { + Ok(port) => Some(port), + Err(e) => { + error!("Error parsing port for TcpIpForward request: {e:?}"); + None + } + } + }; + let _ = return_channel.send(result); + } + Some(GlobalRequestResponse::CancelTcpIpForward(return_channel)) => { + let _ = return_channel.send(true); + } + _ => { + error!("Received global request failure for unknown request!") + } + } + Ok(()) + } + Some(&msg::REQUEST_FAILURE) => { + trace!("global request failure"); + match self.open_global_requests.pop_front() { + Some(GlobalRequestResponse::Keepalive) => { + // ignore keepalives + } + Some(GlobalRequestResponse::TcpIpForward(return_channel)) => { + let _ = return_channel.send(None); + } + Some(GlobalRequestResponse::CancelTcpIpForward(return_channel)) => { + let _ = return_channel.send(false); + } + _ => { + error!("Received global request failure for unknown request!") + } + } + Ok(()) } m => { debug!("unknown message received: {:?}", m); - Ok((handler, self)) + Ok(()) } } } async fn server_handle_channel_open( - mut self, - handler: H, + &mut self, + handler: &mut H, buf: &[u8], - ) -> Result<(H, bool, Self), H::Error> { + ) -> Result { let mut r = buf.reader(1); let msg = OpenChannelMessage::parse(&mut r)?; @@ -1056,23 +1178,23 @@ impl Session { confirmed: true, wants_reply: false, pending_data: std::collections::VecDeque::new(), + pending_eof: false, + pending_close: false, }; - let (sender, receiver) = unbounded_channel(); - let channel = Channel { - id: sender_channel, - sender: self.sender.sender.clone(), - receiver, - max_packet_size: channel_params.recipient_maximum_packet_size, - window_size: channel_params.recipient_window_size, - }; + let (channel, reference) = Channel::new( + sender_channel, + self.sender.sender.clone(), + channel_params.recipient_maximum_packet_size, + channel_params.recipient_window_size, + ); match &msg.typ { ChannelType::Session => { let mut result = handler.channel_open_session(channel, self).await; - if let Ok((_, allowed, s)) = &mut result { - s.channels.insert(sender_channel, sender); - s.finalize_channel_open(&msg, channel_params, *allowed); + if let Ok(allowed) = &mut result { + self.channels.insert(sender_channel, reference); + self.finalize_channel_open(&msg, channel_params, *allowed); } result } @@ -1083,9 +1205,9 @@ impl Session { let mut result = handler .channel_open_x11(channel, originator_address, *originator_port, self) .await; - if let Ok((_, allowed, s)) = &mut result { - s.channels.insert(sender_channel, sender); - s.finalize_channel_open(&msg, channel_params, *allowed); + if let Ok(allowed) = &mut result { + self.channels.insert(sender_channel, reference); + self.finalize_channel_open(&msg, channel_params, *allowed); } result } @@ -1100,9 +1222,9 @@ impl Session { self, ) .await; - if let Ok((_, allowed, s)) = &mut result { - s.channels.insert(sender_channel, sender); - s.finalize_channel_open(&msg, channel_params, *allowed); + if let Ok(allowed) = &mut result { + self.channels.insert(sender_channel, reference); + self.finalize_channel_open(&msg, channel_params, *allowed); } result } @@ -1117,12 +1239,22 @@ impl Session { self, ) .await; - if let Ok((_, allowed, s)) = &mut result { - s.channels.insert(sender_channel, sender); - s.finalize_channel_open(&msg, channel_params, *allowed); + if let Ok(allowed) = &mut result { + self.channels.insert(sender_channel, reference); + self.finalize_channel_open(&msg, channel_params, *allowed); } result } + ChannelType::ForwardedStreamLocal(_) => { + if let Some(ref mut enc) = self.common.encrypted { + msg.fail( + &mut enc.write, + msg::SSH_OPEN_ADMINISTRATIVELY_PROHIBITED, + b"Unsupported channel type", + ); + } + Ok(false) + } ChannelType::AgentForward => { if let Some(ref mut enc) = self.common.encrypted { msg.fail( @@ -1131,14 +1263,14 @@ impl Session { b"Unsupported channel type", ); } - Ok((handler, false, self)) + Ok(false) } ChannelType::Unknown { typ } => { debug!("unknown channel type: {}", String::from_utf8_lossy(typ)); if let Some(ref mut enc) = self.common.encrypted { msg.unknown_type(&mut enc.write); } - Ok((handler, false, self)) + Ok(false) } } } diff --git a/russh/src/server/kex.rs b/russh/src/server/kex.rs index 07a83ecc..6961e01f 100644 --- a/russh/src/server/kex.rs +++ b/russh/src/server/kex.rs @@ -1,12 +1,12 @@ use std::cell::RefCell; -use russh_keys::encoding::{Encoding, Reader}; use log::debug; use super::*; use crate::cipher::SealingKey; use crate::kex::KEXES; use crate::key::PubKey; +use crate::keys::encoding::{Encoding, Reader}; use crate::negotiation::Select; use crate::{msg, negotiation}; @@ -26,7 +26,7 @@ impl KexInit { let algo = { // read algorithms from packet. self.exchange.client_kex_init.extend(buf); - super::negotiation::Server::read_kex(buf, &config.preferred)? + super::negotiation::Server::read_kex(buf, &config.preferred, Some(&config.keys))? }; if !self.sent { self.server_write(config, cipher, write_buffer)? @@ -44,6 +44,7 @@ impl KexInit { session_id: self.session_id, }) } else { + debug!("unknown key {:?}", algo.key); return Err(Error::UnknownKey); }; @@ -60,7 +61,11 @@ impl KexInit { write_buffer: &mut SSHBuffer, ) -> Result<(), Error> { self.exchange.server_kex_init.clear(); - negotiation::write_kex(&config.preferred, &mut self.exchange.server_kex_init, true)?; + negotiation::write_kex( + &config.preferred, + &mut self.exchange.server_kex_init, + Some(config), + )?; debug!("server kex init: {:?}", &self.exchange.server_kex_init[..]); self.sent = true; cipher.write(&self.exchange.server_kex_init, write_buffer); diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index 7f07cae8..c715dbbd 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -16,116 +16,35 @@ //! # Writing servers //! //! There are two ways of accepting connections: -//! * implement the [Server](server::Server) trait and let [run](server::run) handle everything +//! * implement the [Server](server::Server) trait and let [run_on_socket](server::Server::run_on_socket)/[run_on_address](server::Server::run_on_address) handle everything //! * accept connections yourself and pass them to [run_stream](server::run_stream) //! //! In both cases, you'll first need to implement the [Handler](server::Handler) trait - //! this is where you'll handle various events. //! -//! Here is an example server, which forwards input from each client -//! to all other clients: +//! Check out the following examples: //! -//! ``` -//! use async_trait::async_trait; -//! use std::sync::{Mutex, Arc}; -//! use russh::*; -//! use russh::server::{Auth, Session, Msg}; -//! use russh_keys::*; -//! use std::collections::HashMap; -//! use futures::Future; -//! -//! #[tokio::main] -//! async fn main() { -//! let client_key = russh_keys::key::KeyPair::generate_ed25519().unwrap(); -//! let client_pubkey = Arc::new(client_key.clone_public_key().unwrap()); -//! let mut config = russh::server::Config::default(); -//! config.inactivity_timeout = Some(std::time::Duration::from_secs(3)); -//! config.auth_rejection_time = std::time::Duration::from_secs(3); -//! config.keys.push(russh_keys::key::KeyPair::generate_ed25519().unwrap()); -//! let config = Arc::new(config); -//! let sh = Server{ -//! client_pubkey, -//! clients: Arc::new(Mutex::new(HashMap::new())), -//! id: 0 -//! }; -//! tokio::time::timeout( -//! std::time::Duration::from_secs(1), -//! russh::server::run(config, ("0.0.0.0", 2222), sh) -//! ).await.unwrap_or(Ok(())); -//! } -//! -//! #[derive(Clone)] -//! struct Server { -//! client_pubkey: Arc, -//! clients: Arc>>>, -//! id: usize, -//! } -//! -//! impl server::Server for Server { -//! type Handler = Self; -//! fn new_client(&mut self, _: Option) -> Self { -//! let s = self.clone(); -//! self.id += 1; -//! s -//! } -//! } -//! -//! #[async_trait] -//! impl server::Handler for Server { -//! type Error = anyhow::Error; -//! -//! async fn channel_open_session(self, channel: Channel, session: Session) -> Result<(Self, bool, Session), Self::Error> { -//! { -//! let mut clients = self.clients.lock().unwrap(); -//! clients.insert((self.id, channel.id()), channel); -//! } -//! Ok((self, true, session)) -//! } -//! async fn auth_publickey(self, _: &str, _: &key::PublicKey) -> Result<(Self, Auth), Self::Error> { -//! Ok((self, server::Auth::Accept)) -//! } -//! async fn data(self, channel: ChannelId, data: &[u8], mut session: Session) -> Result<(Self, Session), Self::Error> { -//! { -//! let mut clients = self.clients.lock().unwrap(); -//! for ((id, _channel_id), ref mut channel) in clients.iter_mut() { -//! channel.data(data); -//! } -//! } -//! Ok((self, session)) -//! } -//! } -//! ``` -//! -//! Note the call to `session.handle()`, which allows to keep a handle -//! to a client outside the event loop. This feature is internally -//! implemented using `futures::sync::mpsc` channels. -//! -//! Note that this is just a toy server. In particular: -//! -//! - It doesn't handle errors when `s.data` returns an error, i.e. when the -//! client has disappeared -//! -//! - Each new connection increments the `id` field. Even though we -//! would need a lot of connections per second for a very long time to -//! saturate it, there are probably better ways to handle this to -//! avoid collisions. +//! * [Server that forwards your input to all connected clients](https://github.com/warp-tech/russh/blob/main/russh/examples/echoserver.rs) +//! * [Server handing channel processing off to a library (here, `russh-sftp`)](https://github.com/warp-tech/russh/blob/main/russh/examples/sftp_server.rs) +//! * Serving `ratatui` based TUI app to clients: [per-client](https://github.com/warp-tech/russh/blob/main/russh/examples/ratatui_app.rs), [shared](https://github.com/warp-tech/russh/blob/main/russh/examples/ratatui_shared_app.rs) use std; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; +use std::num::Wrapping; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use log::error; use async_trait::async_trait; use futures::future::Future; -use russh_keys::key; +use log::{debug, error}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, ToSocketAddrs}; use tokio::pin; use tokio::task::JoinHandle; use crate::cipher::{clear, CipherPair, OpeningKey}; +use crate::keys::key; use crate::session::*; use crate::ssh_read::*; use crate::sshbuffer::*; @@ -133,7 +52,6 @@ use crate::*; mod kex; mod session; -pub use self::kex::*; pub use self::session::*; mod encrypted; @@ -168,6 +86,10 @@ pub struct Config { pub max_auth_attempts: usize, /// Time after which the connection is garbage-collected. pub inactivity_timeout: Option, + /// If nothing is received from the client for this amount of time, send a keepalive message. + pub keepalive_interval: Option, + /// If this many keepalives have been sent without reply, close the connection. + pub keepalive_max: usize, } impl Default for Config { @@ -190,6 +112,8 @@ impl Default for Config { preferred: Default::default(), max_auth_attempts: 10, inactivity_timeout: Some(std::time::Duration::from_secs(600)), + keepalive_interval: None, + keepalive_max: 3, } } } @@ -254,13 +178,10 @@ pub trait Handler: Sized { /// sure rejection happens in time `config.auth_rejection_time`, /// except if this method takes more than that. #[allow(unused_variables)] - async fn auth_none(self, user: &str) -> Result<(Self, Auth), Self::Error> { - Ok(( - self, - Auth::Reject { - proceed_with_methods: None, - }, - )) + async fn auth_none(&mut self, user: &str) -> Result { + Ok(Auth::Reject { + proceed_with_methods: None, + }) } /// Check authentication using the "password" method. Russh @@ -268,13 +189,10 @@ pub trait Handler: Sized { /// `config.auth_rejection_time`, except if this method takes more /// than that. #[allow(unused_variables)] - async fn auth_password(self, user: &str, password: &str) -> Result<(Self, Auth), Self::Error> { - Ok(( - self, - Auth::Reject { - proceed_with_methods: None, - }, - )) + async fn auth_password(&mut self, user: &str, password: &str) -> Result { + Ok(Auth::Reject { + proceed_with_methods: None, + }) } /// Check authentication using the "publickey" method. This method @@ -285,17 +203,29 @@ pub trait Handler: Sized { /// `config.auth_rejection_time`, except if this method takes more /// time than that. #[allow(unused_variables)] + async fn auth_publickey_offered( + &mut self, + user: &str, + public_key: &key::PublicKey, + ) -> Result { + Ok(Auth::Accept) + } + + /// Check authentication using the "publickey" method. This method + /// is called after the signature has been verified and key + /// ownership has been confirmed. + /// Russh guarantees that rejection happens in constant time + /// `config.auth_rejection_time`, except if this method takes more + /// time than that. + #[allow(unused_variables)] async fn auth_publickey( - self, + &mut self, user: &str, public_key: &key::PublicKey, - ) -> Result<(Self, Auth), Self::Error> { - Ok(( - self, - Auth::Reject { - proceed_with_methods: None, - }, - )) + ) -> Result { + Ok(Auth::Reject { + proceed_with_methods: None, + }) } /// Check authentication using the "keyboard-interactive" @@ -304,97 +234,94 @@ pub trait Handler: Sized { /// than that. #[allow(unused_variables)] async fn auth_keyboard_interactive( - self, + &mut self, user: &str, submethods: &str, response: Option>, - ) -> Result<(Self, Auth), Self::Error> { - Ok(( - self, - Auth::Reject { - proceed_with_methods: None, - }, - )) + ) -> Result { + Ok(Auth::Reject { + proceed_with_methods: None, + }) } /// Called when authentication succeeds for a session. #[allow(unused_variables)] - async fn auth_succeeded(self, session: Session) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + async fn auth_succeeded(&mut self, session: &mut Session) -> Result<(), Self::Error> { + Ok(()) } /// Called when the client closes a channel. #[allow(unused_variables)] async fn channel_close( - self, + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the client sends EOF to a channel. #[allow(unused_variables)] async fn channel_eof( - self, + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when a new session channel is created. /// Return value indicates whether the channel request should be granted. #[allow(unused_variables)] async fn channel_open_session( - self, + &mut self, channel: Channel, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> Result { + Ok(false) } /// Called when a new X11 channel is created. /// Return value indicates whether the channel request should be granted. #[allow(unused_variables)] async fn channel_open_x11( - self, + &mut self, channel: Channel, originator_address: &str, originator_port: u32, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> Result { + Ok(false) } /// Called when a new TCP/IP is created. /// Return value indicates whether the channel request should be granted. #[allow(unused_variables)] async fn channel_open_direct_tcpip( - self, + &mut self, channel: Channel, host_to_connect: &str, port_to_connect: u32, originator_address: &str, originator_port: u32, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> Result { + Ok(false) } /// Called when a new forwarded connection comes in. /// #[allow(unused_variables)] async fn channel_open_forwarded_tcpip( - self, + &mut self, channel: Channel, host_to_connect: &str, port_to_connect: u32, originator_address: &str, originator_port: u32, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> Result { + Ok(false) } /// Called when the client confirmed our request to open a @@ -402,25 +329,25 @@ pub trait Handler: Sized { /// message (this library panics otherwise). #[allow(unused_variables)] async fn channel_open_confirmation( - self, + &mut self, id: ChannelId, max_packet_size: u32, window_size: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when a data packet is received. A response can be /// written to the `response` argument. #[allow(unused_variables)] async fn data( - self, + &mut self, channel: ChannelId, data: &[u8], - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when an extended data packet is received. Code 1 means @@ -429,25 +356,25 @@ pub trait Handler: Sized { /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-5.2)). #[allow(unused_variables)] async fn extended_data( - self, + &mut self, channel: ChannelId, code: u32, data: &[u8], - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when the network window is adjusted, meaning that we /// can send more bytes. #[allow(unused_variables)] async fn window_adjusted( - self, + &mut self, channel: ChannelId, new_size: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Called when this server adjusts the network window. Return the @@ -461,7 +388,7 @@ pub trait Handler: Sized { /// specifications. #[allow(unused_variables, clippy::too_many_arguments)] async fn pty_request( - self, + &mut self, channel: ChannelId, term: &str, col_width: u32, @@ -469,23 +396,23 @@ pub trait Handler: Sized { pix_width: u32, pix_height: u32, modes: &[(Pty, u32)], - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// The client requests an X11 connection. #[allow(unused_variables)] async fn x11_request( - self, + &mut self, channel: ChannelId, single_connection: bool, x11_auth_protocol: &str, x11_auth_cookie: &str, x11_screen_number: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// The client wants to set the given environment variable. Check @@ -493,83 +420,83 @@ pub trait Handler: Sized { /// environment to be set. #[allow(unused_variables)] async fn env_request( - self, + &mut self, channel: ChannelId, variable_name: &str, variable_value: &str, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// The client requests a shell. #[allow(unused_variables)] async fn shell_request( - self, + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// The client sends a command to execute, to be passed to a /// shell. Make sure to check the command before doing so. #[allow(unused_variables)] async fn exec_request( - self, + &mut self, channel: ChannelId, data: &[u8], - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// The client asks to start the subsystem with the given name /// (such as sftp). #[allow(unused_variables)] async fn subsystem_request( - self, + &mut self, channel: ChannelId, name: &str, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// The client's pseudo-terminal window size has changed. #[allow(unused_variables)] async fn window_change_request( - self, + &mut self, channel: ChannelId, col_width: u32, row_height: u32, pix_width: u32, pix_height: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// The client requests OpenSSH agent forwarding #[allow(unused_variables)] async fn agent_request( - self, + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> Result { + Ok(false) } /// The client is sending a signal (usually to pass to the /// currently running process). #[allow(unused_variables)] async fn signal( - self, + &mut self, channel: ChannelId, signal: Sig, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> Result<(), Self::Error> { + Ok(()) } /// Used for reverse-forwarding ports, see @@ -577,56 +504,120 @@ pub trait Handler: Sized { /// If `port` is 0, you should set it to the allocated port number. #[allow(unused_variables)] async fn tcpip_forward( - self, + &mut self, address: &str, port: &mut u32, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> Result { + Ok(false) } /// Used to stop the reverse-forwarding of a port, see /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-7). #[allow(unused_variables)] async fn cancel_tcpip_forward( - self, + &mut self, address: &str, port: u32, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> Result { + Ok(false) + } + + #[allow(unused_variables)] + async fn streamlocal_forward( + &mut self, + socket_path: &str, + session: &mut Session, + ) -> Result { + Ok(false) + } + + #[allow(unused_variables)] + async fn cancel_streamlocal_forward( + &mut self, + socket_path: &str, + session: &mut Session, + ) -> Result { + Ok(false) } } +#[async_trait] /// Trait used to create new handlers when clients connect. pub trait Server { /// The type of handlers. - type Handler: Handler + Send; + type Handler: Handler + Send + 'static; /// Called when a new client connects. fn new_client(&mut self, peer_addr: Option) -> Self::Handler; -} + /// Called when an active connection fails. + fn handle_session_error(&mut self, _error: ::Error) {} + + /// Run a server on a specified `tokio::net::TcpListener`. Useful when dropping + /// privileges immediately after socket binding, for example. + async fn run_on_socket( + &mut self, + config: Arc, + socket: &TcpListener, + ) -> Result<(), std::io::Error> { + if config.maximum_packet_size > 65535 { + error!( + "Maximum packet size ({:?}) should not larger than a TCP packet (65535)", + config.maximum_packet_size + ); + } -/// Run a server. -/// Create a new `Connection` from the server's configuration, a -/// stream and a [`Handler`](trait.Handler.html). -pub async fn run( - config: Arc, - addrs: A, - mut server: H, -) -> Result<(), std::io::Error> { - let socket = TcpListener::bind(addrs).await?; - if config.maximum_packet_size > 65535 { - error!( - "Maximum packet size ({:?}) should not larger than a TCP packet (65535)", - config.maximum_packet_size - ); - } - while let Ok((socket, _)) = socket.accept().await { - let config = config.clone(); - let server = server.new_client(socket.peer_addr().ok()); - tokio::spawn(run_stream(config, socket, server)); - } - Ok(()) + let (error_tx, mut error_rx) = tokio::sync::mpsc::unbounded_channel(); + + loop { + tokio::select! { + accept_result = socket.accept() => { + match accept_result { + Ok((socket, _)) => { + let config = config.clone(); + let handler = self.new_client(socket.peer_addr().ok()); + let error_tx = error_tx.clone(); + tokio::spawn(async move { + let session = match run_stream(config, socket, handler).await { + Ok(s) => s, + Err(e) => { + debug!("Connection setup failed"); + let _ = error_tx.send(e); + return + } + }; + match session.await { + Ok(_) => debug!("Connection closed"), + Err(e) => { + debug!("Connection closed with error"); + let _ = error_tx.send(e); + } + } + }); + } + _ => break, + } + }, + Some(error) = error_rx.recv() => { + self.handle_session_error(error); + } + } + } + + Ok(()) + } + + /// Run a server. + /// Create a new `Connection` from the server's configuration, a + /// stream and a [`Handler`](trait.Handler.html). + async fn run_on_address( + &mut self, + config: Arc, + addrs: A, + ) -> Result<(), std::io::Error> { + let socket = TcpListener::bind(addrs).await?; + self.run_on_socket(config, &socket).await + } } use std::cell::RefCell; @@ -635,14 +626,6 @@ thread_local! { static B2: RefCell = RefCell::new(CryptoVec::new()); } -pub(crate) async fn timeout(delay: Option) { - if let Some(delay) = delay { - tokio::time::sleep(delay).await - } else { - futures::future::pending().await - }; -} - async fn start_reading( mut stream_read: R, mut buffer: SSHBuffer, @@ -713,8 +696,8 @@ where pending_reads: Vec::new(), pending_len: 0, channels: HashMap::new(), + open_global_requests: VecDeque::new(), }; - let join = tokio::spawn(session.run(stream, handler)); Ok(RunningSession { handle, join }) @@ -763,14 +746,36 @@ async fn read_ssh_id( wants_reply: false, disconnected: false, buffer: CryptoVec::new(), + strict_kex: false, + alive_timeouts: 0, + received_data: false, + remote_sshid: sshid.into(), }) } +const STRICT_KEX_MSG_ORDER: &[u8] = &[msg::KEXINIT, msg::KEX_ECDH_INIT, msg::NEWKEYS]; + async fn reply( - mut session: Session, - handler: H, + session: &mut Session, + handler: &mut H, + seqn: &mut Wrapping, buf: &[u8], -) -> Result<(H, Session), H::Error> { +) -> Result<(), H::Error> { + if let Some(message_type) = buf.first() { + if session.common.strict_kex && session.common.encrypted.is_none() { + let seqno = seqn.0 - 1; // was incremented after read() + if let Some(expected) = STRICT_KEX_MSG_ORDER.get(seqno as usize) { + if message_type != expected { + return Err(strict_kex_violation(*message_type, seqno as usize).into()); + } + } + } + + if [msg::IGNORE, msg::UNIMPLEMENTED, msg::DEBUG].contains(message_type) { + return Ok(()); + } + } + // Handle key exchange/re-exchange. if session.common.encrypted.is_none() { match session.common.kex.take() { @@ -782,7 +787,14 @@ async fn reply( buf, &mut session.common.write_buffer, )?); - return Ok((handler, session)); + if let Some(Kex::Dh(KexDh { ref names, .. })) = session.common.kex { + session.common.strict_kex = names.strict_kex; + } + // seqno has already been incremented after read() + if session.common.strict_kex && seqn.0 != 1 { + return Err(strict_kex_violation(msg::KEXINIT, seqn.0 as usize - 1).into()); + } + return Ok(()); } else { // Else, i.e. if the other side has not started // the key exchange, process its packets by simple @@ -797,7 +809,11 @@ async fn reply( buf, &mut session.common.write_buffer, )?); - return Ok((handler, session)); + if let Some(Kex::Keys(_)) = session.common.kex { + // just sent NEWKEYS + session.common.maybe_reset_seqn(); + } + return Ok(()); } Some(Kex::Keys(newkeys)) => { if buf.first() != Some(&msg::NEWKEYS) { @@ -812,16 +828,19 @@ async fn reply( newkeys, ); session.maybe_send_ext_info(); - return Ok((handler, session)); + if session.common.strict_kex { + *seqn = Wrapping(0); + } + return Ok(()); } Some(kex) => { session.common.kex = Some(kex); - return Ok((handler, session)); + return Ok(()); } None => {} } - Ok((handler, session)) + Ok(()) } else { - Ok(session.server_read_encrypted(handler, buf).await?) + Ok(session.server_read_encrypted(handler, seqn, buf).await?) } } diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index 7cbcfeb3..3868a839 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -1,14 +1,16 @@ -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::sync::Arc; use log::debug; -use russh_keys::encoding::{Encoding, Reader}; +use negotiation::parse_kex_algo_list; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver}; +use tokio::sync::{oneshot, Mutex}; use super::*; -use crate::channels::{Channel, ChannelMsg}; +use crate::channels::{Channel, ChannelMsg, ChannelRef}; use crate::kex::EXTENSION_SUPPORT_AS_CLIENT; +use crate::keys::encoding::{Encoding, Reader}; use crate::msg; /// A connected server session. This type is unique to a client. @@ -19,40 +21,54 @@ pub struct Session { pub(crate) target_window_size: u32, pub(crate) pending_reads: Vec, pub(crate) pending_len: u32, - pub(crate) channels: HashMap>, + pub(crate) channels: HashMap, + pub(crate) open_global_requests: VecDeque, } #[derive(Debug)] pub enum Msg { ChannelOpenSession { - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenDirectTcpIp { host_to_connect: String, port_to_connect: u32, originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenForwardedTcpIp { connected_address: String, connected_port: u32, originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, + }, + ChannelOpenForwardedStreamLocal { + server_socket_path: String, + channel_ref: ChannelRef, }, ChannelOpenX11 { originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, }, TcpIpForward { + /// Provide a channel for the reply result to request a reply from the server + reply_channel: Option>>, address: String, port: u32, }, CancelTcpIpForward { + /// Provide a channel for the reply result to request a reply from the server + reply_channel: Option>, address: String, port: u32, }, + Disconnect { + reason: crate::Disconnect, + description: String, + language_tag: String, + }, Channel(ChannelId, ChannelMsg), } @@ -148,19 +164,46 @@ impl Handle { } /// Notifies the client that it can open TCP/IP forwarding channels for a port. - pub async fn forward_tcpip(&self, address: String, port: u32) -> Result<(), ()> { + pub async fn forward_tcpip(&self, address: String, port: u32) -> Result { + let (reply_send, reply_recv) = oneshot::channel(); self.sender - .send(Msg::TcpIpForward { address, port }) + .send(Msg::TcpIpForward { + reply_channel: Some(reply_send), + address, + port, + }) .await - .map_err(|_| ()) + .map_err(|_| ())?; + + match reply_recv.await { + Ok(Some(port)) => Ok(port), + Ok(None) => Err(()), // crate::Error::RequestDenied + Err(e) => { + error!("Unable to receive TcpIpForward result: {e:?}"); + Err(()) // crate::Error::Disconnect + } + } } /// Notifies the client that it can no longer open TCP/IP forwarding channel for a port. pub async fn cancel_forward_tcpip(&self, address: String, port: u32) -> Result<(), ()> { + let (reply_send, reply_recv) = oneshot::channel(); self.sender - .send(Msg::CancelTcpIpForward { address, port }) + .send(Msg::CancelTcpIpForward { + reply_channel: Some(reply_send), + address, + port, + }) .await - .map_err(|_| ()) + .map_err(|_| ())?; + match reply_recv.await { + Ok(true) => Ok(()), + Ok(false) => Err(()), // crate::Error::RequestDenied + Err(e) => { + error!("Unable to receive CancelTcpIpForward result: {e:?}"); + Err(()) // crate::Error::Disconnect + } + } } /// Request a session channel (the most basic type of @@ -170,11 +213,16 @@ impl Handle { /// `confirmed` field of the corresponding `Channel`. pub async fn channel_open_session(&self) -> Result, Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender - .send(Msg::ChannelOpenSession { sender }) + .send(Msg::ChannelOpenSession { channel_ref }) .await .map_err(|_| Error::SendError)?; - self.wait_channel_confirmation(receiver).await + + self.wait_channel_confirmation(receiver, window_size_ref) + .await } /// Open a TCP/IP forwarding channel. This is usually done when a @@ -190,17 +238,21 @@ impl Handle { originator_port: u32, ) -> Result, Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenDirectTcpIp { host_to_connect: host_to_connect.into(), port_to_connect, originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } pub async fn channel_open_forwarded_tcpip, B: Into>( @@ -211,17 +263,40 @@ impl Handle { originator_port: u32, ) -> Result, Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenForwardedTcpIp { connected_address: connected_address.into(), connected_port, originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await + } + + pub async fn channel_open_forwarded_streamlocal>( + &self, + server_socket_path: A, + ) -> Result, Error> { + let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + + self.sender + .send(Msg::ChannelOpenForwardedStreamLocal { + server_socket_path: server_socket_path.into(), + channel_ref, + }) + .await + .map_err(|_| Error::SendError)?; + self.wait_channel_confirmation(receiver, window_size_ref) + .await } pub async fn channel_open_x11>( @@ -230,20 +305,25 @@ impl Handle { originator_port: u32, ) -> Result, Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenX11 { originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } async fn wait_channel_confirmation( &self, mut receiver: UnboundedReceiver, + window_size_ref: Arc>, ) -> Result, Error> { loop { match receiver.recv().await { @@ -252,12 +332,14 @@ impl Handle { max_packet_size, window_size, }) => { + *window_size_ref.lock().await = window_size; + return Ok(Channel { id, sender: self.sender.clone(), receiver, max_packet_size, - window_size, + window_size: window_size_ref, }); } Some(ChannelMsg::OpenFailure(reason)) => { @@ -295,6 +377,23 @@ impl Handle { .await .map_err(|_| ()) } + + /// Allows a server to disconnect a client session + pub async fn disconnect( + &self, + reason: Disconnect, + description: String, + language_tag: String, + ) -> Result<(), Error> { + self.sender + .send(Msg::Disconnect { + reason, + description, + language_tag, + }) + .await + .map_err(|_| Error::SendError) + } } impl Session { @@ -329,17 +428,26 @@ impl Session { let mut opening_cipher = Box::new(clear::Key) as Box; std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); + let keepalive_timer = + future_or_pending(self.common.config.keepalive_interval, tokio::time::sleep); + pin!(keepalive_timer); + + let inactivity_timer = + future_or_pending(self.common.config.inactivity_timeout, tokio::time::sleep); + pin!(inactivity_timer); + let reading = start_reading(stream_read, buffer, opening_cipher); pin!(reading); let mut is_reading = None; let mut decomp = CryptoVec::new(); - let delay = self.common.config.inactivity_timeout; #[allow(clippy::panic)] // false positive in macro while !self.common.disconnected { + self.common.received_data = false; + let mut sent_keepalive = false; tokio::select! { r = &mut reading => { - let (stream_read, buffer, mut opening_cipher) = match r { + let (stream_read, mut buffer, mut opening_cipher) = match r { Ok((_, stream_read, buffer, opening_cipher)) => (stream_read, buffer, opening_cipher), Err(e) => return Err(e.into()) }; @@ -369,14 +477,12 @@ impl Session { debug!("break"); is_reading = Some((stream_read, buffer, opening_cipher)); break; - } else if buf[0] > 4 { + } else { + self.common.received_data = true; std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); // TODO it'd be cleaner to just pass cipher to reply() - match reply(self, handler, buf).await { - Ok((h, s)) => { - handler = h; - self = s; - }, + match reply(&mut self, &mut handler, &mut buffer.seqn, buf).await { + Ok(_) => {}, Err(e) => return Err(e), } std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); @@ -384,10 +490,19 @@ impl Session { } reading.set(start_reading(stream_read, buffer, opening_cipher)); } - _ = timeout(delay) => { + () = &mut keepalive_timer => { + if self.common.config.keepalive_max != 0 && self.common.alive_timeouts > self.common.config.keepalive_max { + debug!("Timeout, client not responding to keepalives"); + return Err(crate::Error::KeepaliveTimeout.into()); + } + self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1); + sent_keepalive = true; + self.keepalive_request(); + } + () = &mut inactivity_timer => { debug!("timeout"); - break - }, + return Err(crate::Error::InactivityTimeout.into()); + } msg = self.receiver.recv(), if !self.is_rekeying() => { match msg { Some(Msg::Channel(id, ChannelMsg::Data { data })) => { @@ -420,27 +535,34 @@ impl Session { Some(Msg::Channel(id, ChannelMsg::WindowAdjusted { new_size })) => { debug!("window adjusted to {:?} for channel {:?}", new_size, id); } - Some(Msg::ChannelOpenSession { sender }) => { + Some(Msg::ChannelOpenSession { channel_ref }) => { let id = self.channel_open_session()?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } - Some(Msg::ChannelOpenDirectTcpIp { host_to_connect, port_to_connect, originator_address, originator_port, sender }) => { + Some(Msg::ChannelOpenDirectTcpIp { host_to_connect, port_to_connect, originator_address, originator_port, channel_ref }) => { let id = self.channel_open_direct_tcpip(&host_to_connect, port_to_connect, &originator_address, originator_port)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } - Some(Msg::ChannelOpenForwardedTcpIp { connected_address, connected_port, originator_address, originator_port, sender }) => { + Some(Msg::ChannelOpenForwardedTcpIp { connected_address, connected_port, originator_address, originator_port, channel_ref }) => { let id = self.channel_open_forwarded_tcpip(&connected_address, connected_port, &originator_address, originator_port)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } - Some(Msg::ChannelOpenX11 { originator_address, originator_port, sender }) => { + Some(Msg::ChannelOpenForwardedStreamLocal { server_socket_path, channel_ref }) => { + let id = self.channel_open_forwarded_streamlocal(&server_socket_path)?; + self.channels.insert(id, channel_ref); + } + Some(Msg::ChannelOpenX11 { originator_address, originator_port, channel_ref }) => { let id = self.channel_open_x11(&originator_address, originator_port)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); + } + Some(Msg::TcpIpForward { address, port, reply_channel }) => { + self.tcpip_forward(&address, port, reply_channel); } - Some(Msg::TcpIpForward { address, port }) => { - self.tcpip_forward(&address, port); + Some(Msg::CancelTcpIpForward { address, port, reply_channel }) => { + self.cancel_tcpip_forward(&address, port, reply_channel); } - Some(Msg::CancelTcpIpForward { address, port }) => { - self.cancel_tcpip_forward(&address, port); + Some(Msg::Disconnect {reason, description, language_tag}) => { + self.common.disconnect(reason, &description, &language_tag); } Some(_) => { // should be unreachable, since the receiver only gets @@ -459,6 +581,31 @@ impl Session { .await .map_err(crate::Error::from)?; self.common.write_buffer.buffer.clear(); + + if self.common.received_data { + // Reset the number of failed keepalive attempts. We don't + // bother detecting keepalive response messages specifically + // (OpenSSH_9.6p1 responds with REQUEST_FAILURE aka 82). Instead + // we assume that the client is still alive if we receive any + // data from it. + self.common.alive_timeouts = 0; + } + if self.common.received_data || sent_keepalive { + if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + keepalive_timer.as_mut().as_pin_mut(), + self.common.config.keepalive_interval, + ) { + sleep.as_mut().reset(tokio::time::Instant::now() + d); + } + } + if !sent_keepalive { + if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + inactivity_timer.as_mut().as_pin_mut(), + self.common.config.inactivity_timeout, + ) { + sleep.as_mut().reset(tokio::time::Instant::now() + d); + } + } } debug!("disconnected"); // Shutdown @@ -701,6 +848,20 @@ impl Session { } } + /// Ping the client to verify there is still connectivity. + pub fn keepalive_request(&mut self) { + let want_reply = u8::from(true); + if let Some(ref mut enc) = self.common.encrypted { + self.open_global_requests + .push_back(GlobalRequestResponse::Keepalive); + push_packet!(enc.write, { + enc.write.push(msg::GLOBAL_REQUEST); + enc.write.extend_ssh_string(b"keepalive@openssh.com"); + enc.write.push(want_reply); + }) + } + } + /// Send the exit status of a program. pub fn exit_status_request(&mut self, channel: ChannelId, exit_status: u32) { if let Some(ref mut enc) = self.common.encrypted { @@ -786,6 +947,16 @@ impl Session { }) } + pub fn channel_open_forwarded_streamlocal( + &mut self, + socket_path: &str, + ) -> Result { + self.channel_open_generic(b"forwarded-streamlocal@openssh.com", |write| { + write.extend_ssh_string(socket_path.as_bytes()); + write.extend_ssh_string(b""); + }) + } + /// Open a new X11 channel, when a connection comes to a /// local port. See [RFC4254](https://tools.ietf.org/html/rfc4254#section-6.3.2). /// TCP/IP packets can then be tunneled through the channel using `.data()`. @@ -848,12 +1019,23 @@ impl Session { /// Requests that the client forward connections to the given host and port. /// See [RFC4254](https://tools.ietf.org/html/rfc4254#section-7). The client /// will open forwarded_tcpip channels for each connection. - pub fn tcpip_forward(&mut self, address: &str, port: u32) { + pub fn tcpip_forward( + &mut self, + address: &str, + port: u32, + reply_channel: Option>>, + ) { if let Some(ref mut enc) = self.common.encrypted { + let want_reply = reply_channel.is_some(); + if let Some(reply_channel) = reply_channel { + self.open_global_requests.push_back( + crate::session::GlobalRequestResponse::TcpIpForward(reply_channel), + ); + } push_packet!(enc.write, { enc.write.push(msg::GLOBAL_REQUEST); enc.write.extend_ssh_string(b"tcpip-forward"); - enc.write.push(0); + enc.write.push(want_reply as u8); enc.write.extend_ssh_string(address.as_bytes()); enc.write.push_u32_be(port); }); @@ -861,18 +1043,42 @@ impl Session { } /// Cancels a previously tcpip_forward request. - pub fn cancel_tcpip_forward(&mut self, address: &str, port: u32) { + pub fn cancel_tcpip_forward( + &mut self, + address: &str, + port: u32, + reply_channel: Option>, + ) { if let Some(ref mut enc) = self.common.encrypted { + let want_reply = reply_channel.is_some(); + if let Some(reply_channel) = reply_channel { + self.open_global_requests.push_back( + crate::session::GlobalRequestResponse::CancelTcpIpForward(reply_channel), + ); + } push_packet!(enc.write, { enc.write.push(msg::GLOBAL_REQUEST); enc.write.extend_ssh_string(b"cancel-tcpip-forward"); - enc.write.push(0); + enc.write.push(want_reply as u8); enc.write.extend_ssh_string(address.as_bytes()); enc.write.push_u32_be(port); }); } } + /// Returns the SSH ID (Protocol Version + Software Version) the client sent when connecting + /// + /// This should contain only ASCII characters for implementations conforming to RFC4253, Section 4.2: + /// + /// > Both the 'protoversion' and 'softwareversion' strings MUST consist of + /// > printable US-ASCII characters, with the exception of whitespace + /// > characters and the minus sign (-). + /// + /// So it usually is fine to convert it to a [`String`] using [`String::from_utf8_lossy`] + pub fn remote_sshid(&self) -> &[u8] { + &self.common.remote_sshid + } + pub(crate) fn maybe_send_ext_info(&mut self) { if let Some(ref mut enc) = self.common.encrypted { // If client sent a ext-info-c message in the kex list, it supports RFC 8308 extension negotiation. @@ -883,9 +1089,10 @@ impl Session { use super::negotiation::Select; key_extension_client = super::negotiation::Server::select( &[EXTENSION_SUPPORT_AS_CLIENT], - kex_string, + &parse_kex_algo_list(kex_string), + AlgorithmKind::Kex, ) - .is_some(); + .is_ok(); } } diff --git a/russh/src/session.rs b/russh/src/session.rs index 09afa95a..0a1f633b 100644 --- a/russh/src/session.rs +++ b/russh/src/session.rs @@ -19,13 +19,15 @@ use std::num::Wrapping; use byteorder::{BigEndian, ByteOrder}; use log::{debug, trace}; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::Encoding; +use tokio::sync::oneshot; use crate::cipher::SealingKey; use crate::kex::KexAlgorithm; +use crate::keys::encoding::Encoding; use crate::sshbuffer::SSHBuffer; -use crate::{auth, cipher, mac, msg, negotiation, ChannelId, ChannelParams, Disconnect, Limits}; +use crate::{ + auth, cipher, mac, msg, negotiation, ChannelId, ChannelParams, CryptoVec, Disconnect, Limits, +}; #[derive(Debug)] pub(crate) struct Encrypted { @@ -53,6 +55,7 @@ pub(crate) struct Encrypted { pub(crate) struct CommonSession { pub auth_user: String, + pub remote_sshid: Vec, pub config: Config, pub encrypted: Option, pub auth_method: Option, @@ -63,6 +66,36 @@ pub(crate) struct CommonSession { pub wants_reply: bool, pub disconnected: bool, pub buffer: CryptoVec, + pub strict_kex: bool, + pub alive_timeouts: usize, + pub received_data: bool, +} + +#[derive(Debug, Clone, Copy)] +pub(crate) enum ChannelFlushResult { + Incomplete { + wrote: usize, + }, + Complete { + wrote: usize, + pending_eof: bool, + pending_close: bool, + }, +} +impl ChannelFlushResult { + pub(crate) fn wrote(&self) -> usize { + match self { + ChannelFlushResult::Incomplete { wrote } => *wrote, + ChannelFlushResult::Complete { wrote, .. } => *wrote, + } + } + pub(crate) fn complete(wrote: usize, channel: &ChannelParams) -> Self { + ChannelFlushResult::Complete { + wrote, + pending_eof: channel.pending_eof, + pending_close: channel.pending_close, + } + } } impl CommonSession { @@ -74,6 +107,7 @@ impl CommonSession { enc.client_mac = newkeys.names.client_mac; enc.server_mac = newkeys.names.server_mac; self.cipher = newkeys.cipher; + self.strict_kex = self.strict_kex || newkeys.names.strict_kex; } } @@ -99,6 +133,7 @@ impl CommonSession { decompress: crate::compression::Decompress::None, }); self.cipher = newkeys.cipher; + self.strict_kex = newkeys.names.strict_kex; } /// Send a disconnect message. @@ -127,6 +162,12 @@ impl CommonSession { enc.byte(channel, msg) } } + + pub(crate) fn maybe_reset_seqn(&mut self) { + if self.strict_kex { + self.write_buffer.seqn = Wrapping(0); + } + } } impl Encrypted { @@ -147,12 +188,20 @@ impl Encrypted { */ pub fn eof(&mut self, channel: ChannelId) { - self.byte(channel, msg::CHANNEL_EOF); + if let Some(channel) = self.has_pending_data_mut(channel) { + channel.pending_eof = true; + } else { + self.byte(channel, msg::CHANNEL_EOF); + } } pub fn close(&mut self, channel: ChannelId) { - self.byte(channel, msg::CHANNEL_CLOSE); - self.channels.remove(&channel); + if let Some(channel) = self.has_pending_data_mut(channel) { + channel.pending_close = true; + } else { + self.byte(channel, msg::CHANNEL_CLOSE); + self.channels.remove(&channel); + } } pub fn sender_window_size(&self, channel: ChannelId) -> usize { @@ -192,33 +241,64 @@ impl Encrypted { false } + fn flush_channel(write: &mut CryptoVec, channel: &mut ChannelParams) -> ChannelFlushResult { + let mut pending_size = 0; + while let Some((buf, a, from)) = channel.pending_data.pop_front() { + let size = Self::data_noqueue(write, channel, &buf, a, from); + pending_size += size; + if from + size < buf.len() { + channel.pending_data.push_front((buf, a, from + size)); + return ChannelFlushResult::Incomplete { + wrote: pending_size, + }; + } + } + ChannelFlushResult::complete(pending_size, channel) + } + + fn handle_flushed_channel(&mut self, channel: ChannelId, flush_result: ChannelFlushResult) { + if let ChannelFlushResult::Complete { + wrote: _, + pending_eof, + pending_close, + } = flush_result + { + if pending_eof { + self.eof(channel); + } + if pending_close { + self.close(channel); + } + } + } + pub fn flush_pending(&mut self, channel: ChannelId) -> usize { let mut pending_size = 0; + let mut maybe_flush_result = Option::::None; + if let Some(channel) = self.channels.get_mut(&channel) { - while let Some((buf, a, from)) = channel.pending_data.pop_front() { - let size = Self::data_noqueue(&mut self.write, channel, &buf, from); - pending_size += size; - if from + size < buf.len() { - channel.pending_data.push_front((buf, a, from + size)); - break; - } - } + let flush_result = Self::flush_channel(&mut self.write, channel); + pending_size += flush_result.wrote(); + maybe_flush_result = Some(flush_result); + } + if let Some(flush_result) = maybe_flush_result { + self.handle_flushed_channel(channel, flush_result) } pending_size } pub fn flush_all_pending(&mut self) { - for (_, channel) in self.channels.iter_mut() { - while let Some((buf, a, from)) = channel.pending_data.pop_front() { - let size = Self::data_noqueue(&mut self.write, channel, &buf, from); - if from + size < buf.len() { - channel.pending_data.push_front((buf, a, from + size)); - break; - } - } + for channel in self.channels.values_mut() { + Self::flush_channel(&mut self.write, channel); } } + fn has_pending_data_mut(&mut self, channel: ChannelId) -> Option<&mut ChannelParams> { + self.channels + .get_mut(&channel) + .filter(|c| !c.pending_data.is_empty()) + } + pub fn has_pending_data(&self, channel: ChannelId) -> bool { if let Some(channel) = self.channels.get(&channel) { !channel.pending_data.is_empty() @@ -234,6 +314,7 @@ impl Encrypted { write: &mut CryptoVec, channel: &mut ChannelParams, buf0: &[u8], + a: Option, from: usize, ) -> usize { if from >= buf0.len() { @@ -251,12 +332,21 @@ impl Encrypted { while !buf.is_empty() { // Compute the length we're allowed to send. let off = std::cmp::min(buf.len(), channel.recipient_maximum_packet_size as usize); - push_packet!(write, { - write.push(msg::CHANNEL_DATA); - write.push_u32_be(channel.recipient_channel); - #[allow(clippy::indexing_slicing)] // length checked - write.extend_ssh_string(&buf[..off]); - }); + match a { + None => push_packet!(write, { + write.push(msg::CHANNEL_DATA); + write.push_u32_be(channel.recipient_channel); + #[allow(clippy::indexing_slicing)] // length checked + write.extend_ssh_string(&buf[..off]); + }), + Some(ext) => push_packet!(write, { + write.push(msg::CHANNEL_EXTENDED_DATA); + write.push_u32_be(channel.recipient_channel); + write.push_u32_be(ext); + #[allow(clippy::indexing_slicing)] // length checked + write.extend_ssh_string(&buf[..off]); + }), + } trace!( "buffer: {:?} {:?}", write.len(), @@ -279,7 +369,7 @@ impl Encrypted { channel.pending_data.push_back((buf0, None, 0)); return; } - let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, 0); + let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, None, 0); if buf_len < buf0.len() { channel.pending_data.push_back((buf0, None, buf_len)) } @@ -289,39 +379,13 @@ impl Encrypted { } pub fn extended_data(&mut self, channel: ChannelId, ext: u32, buf0: CryptoVec) { - use std::ops::Deref; if let Some(channel) = self.channels.get_mut(&channel) { assert!(channel.confirmed); if !channel.pending_data.is_empty() { channel.pending_data.push_back((buf0, Some(ext), 0)); return; } - let mut buf = if buf0.len() as u32 > channel.recipient_window_size { - #[allow(clippy::indexing_slicing)] // length checked - &buf0[0..channel.recipient_window_size as usize] - } else { - &buf0 - }; - let buf_len = buf.len(); - - while !buf.is_empty() { - // Compute the length we're allowed to send. - let off = std::cmp::min(buf.len(), channel.recipient_maximum_packet_size as usize); - push_packet!(self.write, { - self.write.push(msg::CHANNEL_EXTENDED_DATA); - self.write.push_u32_be(channel.recipient_channel); - self.write.push_u32_be(ext); - #[allow(clippy::indexing_slicing)] // length checked - self.write.extend_ssh_string(&buf[..off]); - }); - trace!("buffer: {:?}", self.write.deref().len()); - channel.recipient_window_size -= off as u32; - #[allow(clippy::indexing_slicing)] // length checked - { - buf = &buf[off..] - } - } - trace!("buf.len() = {:?}, buf_len = {:?}", buf.len(), buf_len); + let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, Some(ext), 0); if buf_len < buf0.len() { channel.pending_data.push_back((buf0, Some(ext), buf_len)) } @@ -391,6 +455,8 @@ impl Encrypted { confirmed: false, wants_reply: false, pending_data: std::collections::VecDeque::new(), + pending_eof: false, + pending_close: false, }); return ChannelId(self.last_channel_id.0); } @@ -406,7 +472,7 @@ pub enum EncryptedState { Authenticated, } -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct Exchange { pub client_id: CryptoVec, pub server_id: CryptoVec, @@ -552,3 +618,15 @@ pub(crate) struct NewKeys { pub session_id: CryptoVec, pub sent: bool, } + +pub(crate) enum GlobalRequestResponse { + /// request was for Keepalive, ignore result + Keepalive, + /// request was for TcpIpForward, sends Some(port) for success or None for failure + TcpIpForward(oneshot::Sender>), + /// request was for CancelTcpIpForward, sends true for success or false for failure + CancelTcpIpForward(oneshot::Sender), + /// request was for StreamLocalForward, sends true for success or false for failure + StreamLocalForward(oneshot::Sender), + CancelStreamLocalForward(oneshot::Sender), +} diff --git a/russh/src/ssh_read.rs b/russh/src/ssh_read.rs index d74b19dd..072642c6 100644 --- a/russh/src/ssh_read.rs +++ b/russh/src/ssh_read.rs @@ -1,11 +1,10 @@ use std::pin::Pin; use futures::task::*; -use russh_cryptovec::CryptoVec; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; use log::debug; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; -use crate::Error; +use crate::{CryptoVec, Error}; /// The buffer to read the identification string (first line in the /// protocol). diff --git a/russh/src/tests.rs b/russh/src/tests.rs new file mode 100644 index 00000000..2efcbe40 --- /dev/null +++ b/russh/src/tests.rs @@ -0,0 +1,473 @@ +#![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)] // Allow unwraps, expects and panics in the test suite + +use futures::Future; + +use super::*; + +mod compress { + use std::collections::HashMap; + use std::sync::{Arc, Mutex}; + + use async_trait::async_trait; + use log::debug; + + use super::server::{Server as _, Session}; + use super::*; + use crate::server::Msg; + + #[tokio::test] + async fn compress_local_test() { + let _ = env_logger::try_init(); + + let client_key = russh_keys::key::KeyPair::generate_ed25519().unwrap(); + let mut config = server::Config::default(); + config.preferred = Preferred::COMPRESSED; + config.inactivity_timeout = None; // Some(std::time::Duration::from_secs(3)); + config.auth_rejection_time = std::time::Duration::from_secs(3); + config + .keys + .push(russh_keys::key::KeyPair::generate_ed25519().unwrap()); + let config = Arc::new(config); + let mut sh = Server { + clients: Arc::new(Mutex::new(HashMap::new())), + id: 0, + }; + + let socket = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + + tokio::spawn(async move { + let (socket, _) = socket.accept().await.unwrap(); + let server = sh.new_client(socket.peer_addr().ok()); + server::run_stream(config, socket, server).await.unwrap(); + }); + + let mut config = client::Config::default(); + config.preferred = Preferred::COMPRESSED; + let config = Arc::new(config); + + dbg!(&addr); + let mut session = client::connect(config, addr, Client {}).await.unwrap(); + let authenticated = session + .authenticate_publickey( + std::env::var("USER").unwrap_or("user".to_owned()), + Arc::new(client_key), + ) + .await + .unwrap(); + assert!(authenticated); + let mut channel = session.channel_open_session().await.unwrap(); + + let data = &b"Hello, world!"[..]; + channel.data(data).await.unwrap(); + let msg = channel.wait().await.unwrap(); + match msg { + ChannelMsg::Data { data: msg_data } => { + assert_eq!(*data, *msg_data) + } + msg => panic!("Unexpected message {:?}", msg), + } + } + + #[derive(Clone)] + struct Server { + clients: Arc>>, + id: usize, + } + + impl server::Server for Server { + type Handler = Self; + fn new_client(&mut self, _: Option) -> Self { + let s = self.clone(); + self.id += 1; + s + } + } + + #[async_trait] + impl server::Handler for Server { + type Error = super::Error; + + async fn channel_open_session( + &mut self, + channel: Channel, + session: &mut Session, + ) -> Result { + { + let mut clients = self.clients.lock().unwrap(); + clients.insert((self.id, channel.id()), session.handle()); + } + Ok(true) + } + async fn auth_publickey( + &mut self, + _: &str, + _: &russh_keys::key::PublicKey, + ) -> Result { + debug!("auth_publickey"); + Ok(server::Auth::Accept) + } + async fn data( + &mut self, + channel: ChannelId, + data: &[u8], + session: &mut Session, + ) -> Result<(), Self::Error> { + debug!("server data = {:?}", std::str::from_utf8(data)); + session.data(channel, CryptoVec::from_slice(data)); + Ok(()) + } + } + + struct Client {} + + #[async_trait] + impl client::Handler for Client { + type Error = super::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &russh_keys::key::PublicKey, + ) -> Result { + // println!("check_server_key: {:?}", server_public_key); + Ok(true) + } + } +} + +mod channels { + use async_trait::async_trait; + use server::Session; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + use super::*; + use crate::CryptoVec; + + async fn test_session( + client_handler: CH, + server_handler: SH, + run_client: RC, + run_server: RS, + ) where + RC: FnOnce(crate::client::Handle) -> F1 + Send + Sync + 'static, + RS: FnOnce(crate::server::Handle) -> F2 + Send + Sync + 'static, + F1: Future> + Send + Sync + 'static, + F2: Future + Send + Sync + 'static, + CH: crate::client::Handler + Send + Sync + 'static, + SH: crate::server::Handler + Send + Sync + 'static, + { + use std::sync::Arc; + + use crate::*; + + let _ = env_logger::try_init(); + + let client_key = russh_keys::key::KeyPair::generate_ed25519().unwrap(); + let mut config = server::Config::default(); + config.inactivity_timeout = None; + config.auth_rejection_time = std::time::Duration::from_secs(3); + config + .keys + .push(russh_keys::key::KeyPair::generate_ed25519().unwrap()); + let config = Arc::new(config); + let socket = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + + #[derive(Clone)] + struct Server {} + + let server_join = tokio::spawn(async move { + let (socket, _) = socket.accept().await.unwrap(); + + server::run_stream(config, socket, server_handler) + .await + .map_err(|_| ()) + .unwrap() + }); + + let client_join = tokio::spawn(async move { + let config = Arc::new(client::Config::default()); + let mut session = client::connect(config, addr, client_handler) + .await + .map_err(|_| ()) + .unwrap(); + let authenticated = session + .authenticate_publickey( + std::env::var("USER").unwrap_or("user".to_owned()), + Arc::new(client_key), + ) + .await + .unwrap(); + assert!(authenticated); + session + }); + + let (server_session, client_session) = tokio::join!(server_join, client_join); + let client_handle = tokio::spawn(run_client(client_session.unwrap())); + let server_handle = tokio::spawn(run_server(server_session.unwrap().handle())); + + let (server_session, client_session) = tokio::join!(server_handle, client_handle); + drop(client_session); + drop(server_session); + } + + #[tokio::test] + async fn test_server_channels() { + #[derive(Debug)] + struct Client {} + + #[async_trait] + impl client::Handler for Client { + type Error = crate::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &russh_keys::key::PublicKey, + ) -> Result { + Ok(true) + } + + async fn data( + &mut self, + channel: ChannelId, + data: &[u8], + session: &mut client::Session, + ) -> Result<(), Self::Error> { + assert_eq!(data, &b"hello world!"[..]); + session.data(channel, CryptoVec::from_slice(&b"hey there!"[..])); + Ok(()) + } + } + + struct ServerHandle { + did_auth: Option>, + } + + impl ServerHandle { + fn get_auth_waiter(&mut self) -> tokio::sync::oneshot::Receiver<()> { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.did_auth = Some(tx); + rx + } + } + + #[async_trait] + impl server::Handler for ServerHandle { + type Error = crate::Error; + + async fn auth_publickey( + &mut self, + _: &str, + _: &russh_keys::key::PublicKey, + ) -> Result { + Ok(server::Auth::Accept) + } + async fn auth_succeeded(&mut self, _session: &mut Session) -> Result<(), Self::Error> { + if let Some(a) = self.did_auth.take() { + a.send(()).unwrap(); + } + Ok(()) + } + } + + let mut sh = ServerHandle { did_auth: None }; + let a = sh.get_auth_waiter(); + test_session( + Client {}, + sh, + |c| async move { c }, + |s| async move { + a.await.unwrap(); + let mut ch = s.channel_open_session().await.unwrap(); + ch.data(&b"hello world!"[..]).await.unwrap(); + + let msg = ch.wait().await.unwrap(); + if let ChannelMsg::Data { data } = msg { + assert_eq!(data.as_ref(), &b"hey there!"[..]); + } else { + panic!("Unexpected message {:?}", msg); + } + s + }, + ) + .await; + } + + #[tokio::test] + async fn test_channel_streams() { + #[derive(Debug)] + struct Client {} + + #[async_trait] + impl client::Handler for Client { + type Error = crate::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &russh_keys::key::PublicKey, + ) -> Result { + Ok(true) + } + } + + struct ServerHandle { + channel: Option>>, + } + + impl ServerHandle { + fn get_channel_waiter( + &mut self, + ) -> tokio::sync::oneshot::Receiver> { + let (tx, rx) = tokio::sync::oneshot::channel::>(); + self.channel = Some(tx); + rx + } + } + + #[async_trait] + impl server::Handler for ServerHandle { + type Error = crate::Error; + + async fn auth_publickey( + &mut self, + _: &str, + _: &russh_keys::key::PublicKey, + ) -> Result { + Ok(server::Auth::Accept) + } + + async fn channel_open_session( + &mut self, + channel: Channel, + _session: &mut server::Session, + ) -> Result { + if let Some(a) = self.channel.take() { + println!("channel open session {:?}", a); + a.send(channel).unwrap(); + } + Ok(true) + } + } + + let mut sh = ServerHandle { channel: None }; + let scw = sh.get_channel_waiter(); + + test_session( + Client {}, + sh, + |client| async move { + let ch = client.channel_open_session().await.unwrap(); + let mut stream = ch.into_stream(); + stream.write_all(&b"request"[..]).await.unwrap(); + + let mut buf = Vec::new(); + stream.read_buf(&mut buf).await.unwrap(); + assert_eq!(&buf, &b"response"[..]); + + stream.write_all(&b"reply"[..]).await.unwrap(); + + client + }, + |server| async move { + let channel = scw.await.unwrap(); + let mut stream = channel.into_stream(); + + let mut buf = Vec::new(); + stream.read_buf(&mut buf).await.unwrap(); + assert_eq!(&buf, &b"request"[..]); + + stream.write_all(&b"response"[..]).await.unwrap(); + + buf.clear(); + + stream.read_buf(&mut buf).await.unwrap(); + assert_eq!(&buf, &b"reply"[..]); + + server + }, + ) + .await; + } + + #[tokio::test] + async fn test_channel_objects() { + #[derive(Debug)] + struct Client {} + + #[async_trait] + impl client::Handler for Client { + type Error = crate::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &russh_keys::key::PublicKey, + ) -> Result { + Ok(true) + } + } + + struct ServerHandle {} + + impl ServerHandle {} + + #[async_trait] + impl server::Handler for ServerHandle { + type Error = crate::Error; + + async fn auth_publickey( + &mut self, + _: &str, + _: &russh_keys::key::PublicKey, + ) -> Result { + Ok(server::Auth::Accept) + } + + async fn channel_open_session( + &mut self, + mut channel: Channel, + _session: &mut Session, + ) -> Result { + tokio::spawn(async move { + while let Some(msg) = channel.wait().await { + match msg { + ChannelMsg::Data { data } => { + channel.data(&data[..]).await.unwrap(); + channel.close().await.unwrap(); + break; + } + _ => {} + } + } + }); + Ok(true) + } + } + + let sh = ServerHandle {}; + test_session( + Client {}, + sh, + |c| async move { + let mut ch = c.channel_open_session().await.unwrap(); + ch.data(&b"hello world!"[..]).await.unwrap(); + + let msg = ch.wait().await.unwrap(); + if let ChannelMsg::Data { data } = msg { + assert_eq!(data.as_ref(), &b"hey there!"[..]); + } else { + panic!("Unexpected message {:?}", msg); + } + + let msg = ch.wait().await.unwrap(); + let ChannelMsg::Close = msg else { + panic!("Unexpected message {:?}", msg); + }; + + ch.close().await.unwrap(); + c + }, + |s| async move { s }, + ) + .await; + } +} diff --git a/russh/tests/test_data_stream.rs b/russh/tests/test_data_stream.rs new file mode 100644 index 00000000..3d909961 --- /dev/null +++ b/russh/tests/test_data_stream.rs @@ -0,0 +1,142 @@ +use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::sync::Arc; + +use rand::RngCore; +use russh::keys::key; +use russh::server::{self, Auth, Msg, Server as _, Session}; +use russh::{client, Channel}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +pub const WINDOW_SIZE: u32 = 8 * 2048; + +#[tokio::test] +async fn test_reader_and_writer() -> Result<(), anyhow::Error> { + env_logger::init(); + + let addr = addr(); + let data = data(); + + tokio::spawn(Server::run(addr)); + + // Wait until the server is started + while TcpStream::connect(addr).is_err() { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + + stream(addr, &data).await?; + + Ok(()) +} + +async fn stream(addr: SocketAddr, data: &[u8]) -> Result<(), anyhow::Error> { + let config = Arc::new(client::Config::default()); + let key = Arc::new(russh_keys::key::KeyPair::generate_ed25519().unwrap()); + + let mut session = russh::client::connect(config, addr, Client).await?; + let mut channel = match session.authenticate_publickey("user", key).await { + Ok(true) => session.channel_open_session().await?, + Ok(false) => panic!("Authentication failed"), + Err(err) => return Err(err.into()), + }; + + let mut buf = Vec::::new(); + let (mut writer, mut reader) = (channel.make_writer_ext(Some(1)), channel.make_reader()); + + let (r0, r1) = tokio::join!( + async { + writer.write_all(data).await?; + writer.shutdown().await?; + + Ok::<_, anyhow::Error>(()) + }, + reader.read_to_end(&mut buf) + ); + + r0?; + let count = r1?; + + assert_eq!(data.len(), count); + assert_eq!(data, buf); + + Ok(()) +} + +fn data() -> Vec { + let mut rng = rand::thread_rng(); + + let mut data = vec![0u8; WINDOW_SIZE as usize * 2 + 7]; // Check whether the window_size resizing works + rng.fill_bytes(&mut data); + + data +} + +/// Find a unused local address to bind our server to +fn addr() -> SocketAddr { + TcpListener::bind(("127.0.0.1", 0)) + .unwrap() + .local_addr() + .unwrap() +} + +#[derive(Clone)] +struct Server; + +impl Server { + async fn run(addr: SocketAddr) { + let config = Arc::new(server::Config { + keys: vec![russh_keys::key::KeyPair::generate_ed25519().unwrap()], + window_size: WINDOW_SIZE, + ..Default::default() + }); + let mut sh = Server {}; + + sh.run_on_address(config, addr).await.unwrap(); + } +} + +impl russh::server::Server for Server { + type Handler = Self; + + fn new_client(&mut self, _: Option) -> Self::Handler { + self.clone() + } +} + +#[async_trait::async_trait] +impl russh::server::Handler for Server { + type Error = anyhow::Error; + + async fn auth_publickey(&mut self, _: &str, _: &key::PublicKey) -> Result { + Ok(Auth::Accept) + } + + async fn channel_open_session( + &mut self, + mut channel: Channel, + _session: &mut Session, + ) -> Result { + tokio::spawn(async move { + let (mut writer, mut reader) = + (channel.make_writer(), channel.make_reader_ext(Some(1))); + + tokio::io::copy(&mut reader, &mut writer) + .await + .expect("Data transfer failed"); + + writer.shutdown().await.expect("Shutdown failed"); + }); + + Ok(true) + } +} + +struct Client; + +#[async_trait::async_trait] +impl russh::client::Handler for Client { + type Error = anyhow::Error; + + async fn check_server_key(&mut self, _: &key::PublicKey) -> Result { + Ok(true) + } +} diff --git a/rust-toolchain.toml b/rust-toolchain.toml index f8c2abbb..624eb0ea 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "1.65.0" +channel = "1.76.0"