diff --git a/windows_kext/driver/Cargo.lock b/windows_kext/driver/Cargo.lock index b87467456..4a0e0b397 100644 --- a/windows_kext/driver/Cargo.lock +++ b/windows_kext/driver/Cargo.lock @@ -2,18 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "ahash" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a" -dependencies = [ - "cfg-if", - "once_cell", - "version_check", - "zerocopy", -] - [[package]] name = "atomic-polyfill" version = "1.0.3" @@ -57,7 +45,6 @@ checksum = "7059fff8937831a9ae6f0fe4d658ffabf58f2ca96aa9dec1c889f936f705f216" name = "driver" version = "0.0.0" dependencies = [ - "hashbrown", "num", "num-derive", "num-traits", @@ -76,15 +63,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "hashbrown" -version = "0.14.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" -dependencies = [ - "ahash", -] - [[package]] name = "heapless" version = "0.7.17" @@ -217,12 +195,6 @@ dependencies = [ "syn", ] -[[package]] -name = "once_cell" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" - [[package]] name = "proc-macro2" version = "1.0.78" @@ -316,12 +288,6 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" -[[package]] -name = "version_check" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" - [[package]] name = "wdk" version = "0.0.0" @@ -399,23 +365,3 @@ source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e name = "windows_x86_64_msvc" version = "0.52.5" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" - -[[package]] -name = "zerocopy" -version = "0.7.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d6f15f7ade05d2a4935e34a457b936c23dc70a05cc1d97133dc99e7a3fe0f0e" -dependencies = [ - "zerocopy-derive", -] - -[[package]] -name = "zerocopy-derive" -version = "0.7.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbbad221e3f78500350ecbd7dfa4e63ef945c05f4c61cb7f4d3f84cd0bba649b" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] diff --git a/windows_kext/driver/Cargo.toml b/windows_kext/driver/Cargo.toml index 66dffacab..034627c0b 100644 --- a/windows_kext/driver/Cargo.toml +++ b/windows_kext/driver/Cargo.toml @@ -17,7 +17,6 @@ num = { version = "0.4", default-features = false } num-derive = { version = "0.4", default-features = false } num-traits = { version = "0.2", default-features = false } smoltcp = { version = "0.10", default-features = false, features = ["proto-ipv4", "proto-ipv6"] } -hashbrown = { version = "0.14.3", default-features = false, features = ["ahash"]} # WARNING: Do not update. The version was choosen for a reason. See wdk/README.md for more detiels. [dependencies.windows-sys] diff --git a/windows_kext/driver/src/bandwidth.rs b/windows_kext/driver/src/bandwidth.rs index 4fb487867..0105ac72e 100644 --- a/windows_kext/driver/src/bandwidth.rs +++ b/windows_kext/driver/src/bandwidth.rs @@ -1,14 +1,10 @@ +use alloc::collections::BTreeMap; use protocol::info::{BandwidthValueV4, BandwidthValueV6, Info}; use smoltcp::wire::{IpProtocol, Ipv4Address, Ipv6Address}; use wdk::rw_spin_lock::RwSpinLock; -use crate::driver_hashmap::DeviceHashMap; - -#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] -pub struct Key
-where - Address: Eq + PartialEq, -{ +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] +pub struct Key { pub local_ip: Address, pub local_port: u16, pub remote_ip: Address, @@ -25,32 +21,32 @@ enum Direction { Rx(usize), } pub struct Bandwidth { - stats_tcp_v4: DeviceHashMap, Value>, + stats_tcp_v4: BTreeMap, Value>, stats_tcp_v4_lock: RwSpinLock, - stats_tcp_v6: DeviceHashMap, Value>, + stats_tcp_v6: BTreeMap, Value>, stats_tcp_v6_lock: RwSpinLock, - stats_udp_v4: DeviceHashMap, Value>, + stats_udp_v4: BTreeMap, Value>, stats_udp_v4_lock: RwSpinLock, - stats_udp_v6: DeviceHashMap, Value>, + stats_udp_v6: BTreeMap, Value>, stats_udp_v6_lock: RwSpinLock, } impl Bandwidth { pub fn new() -> Self { Self { - stats_tcp_v4: DeviceHashMap::new(), + stats_tcp_v4: BTreeMap::new(), stats_tcp_v4_lock: RwSpinLock::default(), - stats_tcp_v6: DeviceHashMap::new(), + stats_tcp_v6: BTreeMap::new(), stats_tcp_v6_lock: RwSpinLock::default(), - stats_udp_v4: DeviceHashMap::new(), + stats_udp_v4: BTreeMap::new(), stats_udp_v4_lock: RwSpinLock::default(), - stats_udp_v6: DeviceHashMap::new(), + stats_udp_v6: BTreeMap::new(), stats_udp_v6_lock: RwSpinLock::default(), } } @@ -62,7 +58,7 @@ impl Bandwidth { if self.stats_tcp_v4.is_empty() { return None; } - stats_map = core::mem::replace(&mut self.stats_tcp_v4, DeviceHashMap::new()); + stats_map = core::mem::replace(&mut self.stats_tcp_v4, BTreeMap::new()); } let mut values = alloc::vec::Vec::with_capacity(stats_map.len()); @@ -89,7 +85,7 @@ impl Bandwidth { if self.stats_tcp_v6.is_empty() { return None; } - stats_map = core::mem::replace(&mut self.stats_tcp_v6, DeviceHashMap::new()); + stats_map = core::mem::replace(&mut self.stats_tcp_v6, BTreeMap::new()); } let mut values = alloc::vec::Vec::with_capacity(stats_map.len()); @@ -116,7 +112,7 @@ impl Bandwidth { if self.stats_udp_v4.is_empty() { return None; } - stats_map = core::mem::replace(&mut self.stats_udp_v4, DeviceHashMap::new()); + stats_map = core::mem::replace(&mut self.stats_udp_v4, BTreeMap::new()); } let mut values = alloc::vec::Vec::with_capacity(stats_map.len()); @@ -140,10 +136,10 @@ impl Bandwidth { let stats_map; { let _guard = self.stats_udp_v6_lock.write_lock(); - if self.stats_tcp_v6.is_empty() { + if self.stats_udp_v6.is_empty() { return None; } - stats_map = core::mem::replace(&mut self.stats_tcp_v6, DeviceHashMap::new()); + stats_map = core::mem::replace(&mut self.stats_udp_v6, BTreeMap::new()); } let mut values = alloc::vec::Vec::with_capacity(stats_map.len()); @@ -235,8 +231,8 @@ impl Bandwidth { ); } - fn update( - map: &mut DeviceHashMap, Value>, + fn update( + map: &mut BTreeMap, Value>, lock: &mut RwSpinLock, key: Key
, bytes: Direction, diff --git a/windows_kext/driver/src/connection_cache.rs b/windows_kext/driver/src/connection_cache.rs index 665e60f41..a076df022 100644 --- a/windows_kext/driver/src/connection_cache.rs +++ b/windows_kext/driver/src/connection_cache.rs @@ -1,10 +1,8 @@ -use core::time::Duration; - use crate::{ connection::{Connection, ConnectionV4, ConnectionV6, RedirectInfo, Verdict}, connection_map::{ConnectionMap, Key}, }; -use alloc::{format, string::String, vec::Vec}; +use alloc::vec::Vec; use smoltcp::wire::IpProtocol; use wdk::rw_spin_lock::RwSpinLock; @@ -128,73 +126,4 @@ impl ConnectionCache { return size; } - - #[allow(dead_code)] - pub fn get_full_cache_info(&self) -> String { - let mut info = String::new(); - let now = wdk::utils::get_system_timestamp_ms(); - { - let _guard = self.lock_v4.read_lock(); - for ((protocol, port), connections) in self.connections_v4.iter() { - info.push_str(&format!("{} -> {}\n", protocol, port,)); - for conn in connections { - let active_time_seconds = - Duration::from_millis(now - conn.get_last_accessed_time()).as_secs(); - info.push_str(&format!( - "\t{}:{} -> {}:{} {} last active {}m {}s ago", - conn.local_address, - conn.local_port, - conn.remote_address, - conn.remote_port, - conn.verdict, - active_time_seconds / 60, - active_time_seconds % 60 - )); - if conn.has_ended() { - let end_time_seconds = - Duration::from_millis(now - conn.get_end_time()).as_secs(); - info.push_str(&format!( - "\t ended {}m {}s ago", - end_time_seconds / 60, - end_time_seconds % 60 - )); - } - info.push('\n'); - } - } - } - - { - let _guard = self.lock_v6.read_lock(); - for ((protocol, port), connections) in self.connections_v6.iter() { - info.push_str(&format!("{} -> {} \n", protocol, port)); - for conn in connections { - let active_time_seconds = - Duration::from_millis(now - conn.get_last_accessed_time()).as_secs(); - info.push_str(&format!( - "\t{}:{} -> {}:{} {} last active {}m {}s ago", - conn.local_address, - conn.local_port, - conn.remote_address, - conn.remote_port, - conn.verdict, - active_time_seconds / 60, - active_time_seconds % 60 - )); - if conn.has_ended() { - let end_time_seconds = - Duration::from_millis(now - conn.get_end_time()).as_secs(); - info.push_str(&format!( - "\t ended {}m {}s ago", - end_time_seconds / 60, - end_time_seconds % 60 - )); - } - info.push('\n'); - } - } - } - - return info; - } } diff --git a/windows_kext/driver/src/connection_map.rs b/windows_kext/driver/src/connection_map.rs index bf2210f8b..7cadf87be 100644 --- a/windows_kext/driver/src/connection_map.rs +++ b/windows_kext/driver/src/connection_map.rs @@ -1,8 +1,7 @@ use core::{fmt::Display, time::Duration}; use crate::connection::Connection; -use alloc::vec::Vec; -use hashbrown::HashMap; +use alloc::{collections::BTreeMap, vec::Vec}; use smoltcp::wire::{IpAddress, IpProtocol}; #[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord)] @@ -63,11 +62,11 @@ impl Key { } } -pub struct ConnectionMap(HashMap<(IpProtocol, u16), Vec>); +pub struct ConnectionMap(BTreeMap<(IpProtocol, u16), Vec>); impl ConnectionMap { pub fn new() -> Self { - Self(HashMap::new()) + Self(BTreeMap::new()) } pub fn add(&mut self, conn: T) { @@ -164,7 +163,6 @@ impl ConnectionMap { self.0.retain(|_, v| !v.is_empty()); } - #[allow(dead_code)] pub fn get_count(&self) -> usize { let mut count = 0; for conn in self.0.values() { @@ -172,8 +170,4 @@ impl ConnectionMap { } return count; } - - pub fn iter(&self) -> hashbrown::hash_map::Iter<'_, (IpProtocol, u16), Vec> { - self.0.iter() - } } diff --git a/windows_kext/driver/src/driver_hashmap.rs b/windows_kext/driver/src/driver_hashmap.rs deleted file mode 100644 index 1c8b706ab..000000000 --- a/windows_kext/driver/src/driver_hashmap.rs +++ /dev/null @@ -1,25 +0,0 @@ -use core::ops::{Deref, DerefMut}; - -use hashbrown::HashMap; - -pub struct DeviceHashMap(Option>); - -impl DeviceHashMap { - pub fn new() -> Self { - Self(Some(HashMap::new())) - } -} - -impl Deref for DeviceHashMap { - type Target = HashMap; - - fn deref(&self) -> &Self::Target { - self.0.as_ref().unwrap() - } -} - -impl DerefMut for DeviceHashMap { - fn deref_mut(&mut self) -> &mut Self::Target { - self.0.as_mut().unwrap() - } -} diff --git a/windows_kext/driver/src/lib.rs b/windows_kext/driver/src/lib.rs index d13e9d3f3..7d9fe3a1d 100644 --- a/windows_kext/driver/src/lib.rs +++ b/windows_kext/driver/src/lib.rs @@ -13,7 +13,6 @@ mod connection; mod connection_cache; mod connection_map; mod device; -mod driver_hashmap; mod entry; mod id_cache; pub mod logger; diff --git a/windows_kext/driver/src/stream_callouts.rs b/windows_kext/driver/src/stream_callouts.rs index a63937643..f0a0f0d04 100644 --- a/windows_kext/driver/src/stream_callouts.rs +++ b/windows_kext/driver/src/stream_callouts.rs @@ -4,6 +4,8 @@ use wdk::filter_engine::{callout_data::CalloutData, layer, net_buffer::NetBuffer use crate::{bandwidth, connection::Direction}; pub fn stream_layer_tcp_v4(data: CalloutData) { + type Fields = layer::FieldsStreamV4; + let Some(device) = crate::entry::get_device() else { return; }; @@ -16,7 +18,6 @@ pub fn stream_layer_tcp_v4(data: CalloutData) { } else { return; }; - type Fields = layer::FieldsStreamV4; let local_ip = Ipv4Address::from_bytes( &data .get_value_u32(Fields::IpLocalAddress as usize) @@ -56,6 +57,8 @@ pub fn stream_layer_tcp_v4(data: CalloutData) { } pub fn stream_layer_tcp_v6(data: CalloutData) { + type Fields = layer::FieldsStreamV6; + let Some(device) = crate::entry::get_device() else { return; }; @@ -68,16 +71,18 @@ pub fn stream_layer_tcp_v6(data: CalloutData) { } else { return; }; - type Fields = layer::FieldsStreamV6; + if data_length == 0 { return; } let local_ip = Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpLocalAddress as usize)); let local_port = data.get_value_u16(Fields::IpLocalPort as usize); + let remote_ip = Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpRemoteAddress as usize)); let remote_port = data.get_value_u16(Fields::IpRemotePort as usize); + match direction { Direction::Outbound => { device.bandwidth_stats.update_tcp_v6_tx( @@ -105,6 +110,8 @@ pub fn stream_layer_tcp_v6(data: CalloutData) { } pub fn stream_layer_udp_v4(data: CalloutData) { + type Fields = layer::FieldsDatagramDataV4; + let Some(device) = crate::entry::get_device() else { return; }; @@ -112,7 +119,6 @@ pub fn stream_layer_udp_v4(data: CalloutData) { for nbl in NetBufferListIter::new(data.get_layer_data() as _) { data_length += nbl.get_data_length() as usize; } - type Fields = layer::FieldsDatagramDataV4; let mut direction = Direction::Inbound; if data.get_value_u8(Fields::Direction as usize) == 0 { direction = Direction::Outbound; @@ -157,6 +163,8 @@ pub fn stream_layer_udp_v4(data: CalloutData) { } pub fn stream_layer_udp_v6(data: CalloutData) { + type Fields = layer::FieldsDatagramDataV6; + let Some(device) = crate::entry::get_device() else { return; }; @@ -164,7 +172,6 @@ pub fn stream_layer_udp_v6(data: CalloutData) { for nbl in NetBufferListIter::new(data.get_layer_data() as _) { data_length += nbl.get_data_length() as usize; } - type Fields = layer::FieldsDatagramDataV6; let mut direction = Direction::Inbound; if data.get_value_u8(Fields::Direction as usize) == 0 { direction = Direction::Outbound; diff --git a/windows_kext/kextinterface/kext.go b/windows_kext/kextinterface/kext.go index 8322ead8f..c9b61c7c8 100644 --- a/windows_kext/kextinterface/kext.go +++ b/windows_kext/kextinterface/kext.go @@ -38,7 +38,7 @@ var ( ) const ( - winInvalidHandleValue = windows.Handle(^uintptr(0)) // Max value + winInvalidHandleValue = windows.InvalidHandle stopServiceTimeoutDuration = time.Duration(30 * time.Second) ) @@ -48,7 +48,7 @@ type KextService struct { } func (s *KextService) isValid() bool { - return s != nil && s.handle != winInvalidHandleValue && s.handle != 0 + return s != nil && s.handle != windows.InvalidHandle && s.handle != 0 } func (s *KextService) isRunning() (bool, error) { @@ -99,7 +99,7 @@ func (s *KextService) Start(wait bool) error { _ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status) _ = windows.DeleteService(s.handle) _ = windows.CloseServiceHandle(s.handle) - s.handle = winInvalidHandleValue + s.handle = windows.InvalidHandle return err } } @@ -158,7 +158,7 @@ func (s *KextService) Delete() error { return fmt.Errorf("failed to close service handle: %s", err) } - s.handle = winInvalidHandleValue + s.handle = windows.InvalidHandle return nil } @@ -234,7 +234,7 @@ func CreateKextService(driverName string, driverPath string) (*KextService, erro return nil, err } - service = winInvalidHandleValue + service = windows.InvalidHandle log.Warning("kext: old driver service was deleted successfully") } diff --git a/windows_kext/kextinterface/kext_file.go b/windows_kext/kextinterface/kext_file.go index 045ee06e5..fac6c5cdd 100644 --- a/windows_kext/kextinterface/kext_file.go +++ b/windows_kext/kextinterface/kext_file.go @@ -4,6 +4,8 @@ package kextinterface import ( + "fmt" + "golang.org/x/sys/windows" ) @@ -13,7 +15,16 @@ type KextFile struct { read_slice []byte } +// Read tries to read the supplied buffer length from the driver. +// The data from the driver is read in chunks `len(f.buffer)` and the extra data is cached for the next call. +// The performance penalty of calling the function with small buffers is very small. +// The function will block until the next info packet is received from the kext. func (f *KextFile) Read(buffer []byte) (int, error) { + if err := f.IsValid(); err != nil { + return 0, fmt.Errorf("failed to read: %w", err) + } + + // If no data is available from previous calls, read from kext. if f.read_slice == nil || len(f.read_slice) == 0 { err := f.refill_read_buffer() if err != nil { @@ -22,14 +33,19 @@ func (f *KextFile) Read(buffer []byte) (int, error) { } if len(f.read_slice) >= len(buffer) { - // Write all requested bytes. + // There is enough data to fill the requested buffer. copy(buffer, f.read_slice[0:len(buffer)]) + // Move the slice to contain the remaining data. f.read_slice = f.read_slice[len(buffer):] } else { - // Write all available bytes and read again. + // There is not enough data to fill the requested buffer. + + // Write everything available. copy(buffer[0:len(f.read_slice)], f.read_slice) copiedBytes := len(f.read_slice) f.read_slice = nil + + // Read again. _, err := f.Read(buffer[copiedBytes:]) if err != nil { return 0, err @@ -51,20 +67,33 @@ func (f *KextFile) refill_read_buffer() error { return nil } +// Write sends the buffer bytes to the kext. The function will block until the whole buffer is written to the kext. func (f *KextFile) Write(buffer []byte) (int, error) { + if err := f.IsValid(); err != nil { + return 0, fmt.Errorf("failed to write: %w", err) + } var count uint32 = 0 overlapped := &windows.Overlapped{} err := windows.WriteFile(f.handle, buffer, &count, overlapped) return int(count), err } +// Close closes the handle to the kext. This will cancel all active Reads and Writes. func (f *KextFile) Close() error { + if err := f.IsValid(); err != nil { + return fmt.Errorf("failed to close: %w", err) + } err := windows.CloseHandle(f.handle) - f.handle = winInvalidHandleValue + f.handle = windows.InvalidHandle return err } +// deviceIOControl exists for compatibility with the old kext. func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) (*windows.Overlapped, error) { + if err := f.IsValid(); err != nil { + return nil, fmt.Errorf("failed to send io control: %w", err) + } + // Prepare the input data var inDataPtr *byte = nil var inDataSize uint32 = 0 if inData != nil { @@ -72,6 +101,7 @@ func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) ( inDataSize = uint32(len(inData)) } + // Prepare the output data var outDataPtr *byte = nil var outDataSize uint32 = 0 if outData != nil { @@ -79,6 +109,7 @@ func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) ( outDataSize = uint32(len(outData)) } + // Make the request to the kext. overlapped := &windows.Overlapped{} err := windows.DeviceIoControl(f.handle, code, @@ -92,6 +123,20 @@ func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) ( return overlapped, nil } +// GetHandle returns the handle of the kext. func (f *KextFile) GetHandle() windows.Handle { return f.handle } + +// IsValid checks if kext file holds a valid handle to the kext driver. +func (f *KextFile) IsValid() error { + if f == nil { + return fmt.Errorf("nil kext file") + } + + if f.handle == windows.Handle(0) || f.handle == windows.InvalidHandle { + return fmt.Errorf("invalid handle") + } + + return nil +}