diff --git a/windows_kext/driver/Cargo.lock b/windows_kext/driver/Cargo.lock
index b87467456..4a0e0b397 100644
--- a/windows_kext/driver/Cargo.lock
+++ b/windows_kext/driver/Cargo.lock
@@ -2,18 +2,6 @@
# It is not intended for manual editing.
version = 3
-[[package]]
-name = "ahash"
-version = "0.8.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a"
-dependencies = [
- "cfg-if",
- "once_cell",
- "version_check",
- "zerocopy",
-]
-
[[package]]
name = "atomic-polyfill"
version = "1.0.3"
@@ -57,7 +45,6 @@ checksum = "7059fff8937831a9ae6f0fe4d658ffabf58f2ca96aa9dec1c889f936f705f216"
name = "driver"
version = "0.0.0"
dependencies = [
- "hashbrown",
"num",
"num-derive",
"num-traits",
@@ -76,15 +63,6 @@ dependencies = [
"byteorder",
]
-[[package]]
-name = "hashbrown"
-version = "0.14.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604"
-dependencies = [
- "ahash",
-]
-
[[package]]
name = "heapless"
version = "0.7.17"
@@ -217,12 +195,6 @@ dependencies = [
"syn",
]
-[[package]]
-name = "once_cell"
-version = "1.19.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
-
[[package]]
name = "proc-macro2"
version = "1.0.78"
@@ -316,12 +288,6 @@ version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
-[[package]]
-name = "version_check"
-version = "0.9.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
-
[[package]]
name = "wdk"
version = "0.0.0"
@@ -399,23 +365,3 @@ source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e
name = "windows_x86_64_msvc"
version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
-
-[[package]]
-name = "zerocopy"
-version = "0.7.28"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7d6f15f7ade05d2a4935e34a457b936c23dc70a05cc1d97133dc99e7a3fe0f0e"
-dependencies = [
- "zerocopy-derive",
-]
-
-[[package]]
-name = "zerocopy-derive"
-version = "0.7.28"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dbbad221e3f78500350ecbd7dfa4e63ef945c05f4c61cb7f4d3f84cd0bba649b"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn",
-]
diff --git a/windows_kext/driver/Cargo.toml b/windows_kext/driver/Cargo.toml
index 66dffacab..034627c0b 100644
--- a/windows_kext/driver/Cargo.toml
+++ b/windows_kext/driver/Cargo.toml
@@ -17,7 +17,6 @@ num = { version = "0.4", default-features = false }
num-derive = { version = "0.4", default-features = false }
num-traits = { version = "0.2", default-features = false }
smoltcp = { version = "0.10", default-features = false, features = ["proto-ipv4", "proto-ipv6"] }
-hashbrown = { version = "0.14.3", default-features = false, features = ["ahash"]}
# WARNING: Do not update. The version was choosen for a reason. See wdk/README.md for more detiels.
[dependencies.windows-sys]
diff --git a/windows_kext/driver/src/bandwidth.rs b/windows_kext/driver/src/bandwidth.rs
index 4fb487867..0105ac72e 100644
--- a/windows_kext/driver/src/bandwidth.rs
+++ b/windows_kext/driver/src/bandwidth.rs
@@ -1,14 +1,10 @@
+use alloc::collections::BTreeMap;
use protocol::info::{BandwidthValueV4, BandwidthValueV6, Info};
use smoltcp::wire::{IpProtocol, Ipv4Address, Ipv6Address};
use wdk::rw_spin_lock::RwSpinLock;
-use crate::driver_hashmap::DeviceHashMap;
-
-#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
-pub struct Key
-where
- Address: Eq + PartialEq,
-{
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
+pub struct Key {
pub local_ip: Address,
pub local_port: u16,
pub remote_ip: Address,
@@ -25,32 +21,32 @@ enum Direction {
Rx(usize),
}
pub struct Bandwidth {
- stats_tcp_v4: DeviceHashMap, Value>,
+ stats_tcp_v4: BTreeMap, Value>,
stats_tcp_v4_lock: RwSpinLock,
- stats_tcp_v6: DeviceHashMap, Value>,
+ stats_tcp_v6: BTreeMap, Value>,
stats_tcp_v6_lock: RwSpinLock,
- stats_udp_v4: DeviceHashMap, Value>,
+ stats_udp_v4: BTreeMap, Value>,
stats_udp_v4_lock: RwSpinLock,
- stats_udp_v6: DeviceHashMap, Value>,
+ stats_udp_v6: BTreeMap, Value>,
stats_udp_v6_lock: RwSpinLock,
}
impl Bandwidth {
pub fn new() -> Self {
Self {
- stats_tcp_v4: DeviceHashMap::new(),
+ stats_tcp_v4: BTreeMap::new(),
stats_tcp_v4_lock: RwSpinLock::default(),
- stats_tcp_v6: DeviceHashMap::new(),
+ stats_tcp_v6: BTreeMap::new(),
stats_tcp_v6_lock: RwSpinLock::default(),
- stats_udp_v4: DeviceHashMap::new(),
+ stats_udp_v4: BTreeMap::new(),
stats_udp_v4_lock: RwSpinLock::default(),
- stats_udp_v6: DeviceHashMap::new(),
+ stats_udp_v6: BTreeMap::new(),
stats_udp_v6_lock: RwSpinLock::default(),
}
}
@@ -62,7 +58,7 @@ impl Bandwidth {
if self.stats_tcp_v4.is_empty() {
return None;
}
- stats_map = core::mem::replace(&mut self.stats_tcp_v4, DeviceHashMap::new());
+ stats_map = core::mem::replace(&mut self.stats_tcp_v4, BTreeMap::new());
}
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
@@ -89,7 +85,7 @@ impl Bandwidth {
if self.stats_tcp_v6.is_empty() {
return None;
}
- stats_map = core::mem::replace(&mut self.stats_tcp_v6, DeviceHashMap::new());
+ stats_map = core::mem::replace(&mut self.stats_tcp_v6, BTreeMap::new());
}
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
@@ -116,7 +112,7 @@ impl Bandwidth {
if self.stats_udp_v4.is_empty() {
return None;
}
- stats_map = core::mem::replace(&mut self.stats_udp_v4, DeviceHashMap::new());
+ stats_map = core::mem::replace(&mut self.stats_udp_v4, BTreeMap::new());
}
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
@@ -140,10 +136,10 @@ impl Bandwidth {
let stats_map;
{
let _guard = self.stats_udp_v6_lock.write_lock();
- if self.stats_tcp_v6.is_empty() {
+ if self.stats_udp_v6.is_empty() {
return None;
}
- stats_map = core::mem::replace(&mut self.stats_tcp_v6, DeviceHashMap::new());
+ stats_map = core::mem::replace(&mut self.stats_udp_v6, BTreeMap::new());
}
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
@@ -235,8 +231,8 @@ impl Bandwidth {
);
}
- fn update(
- map: &mut DeviceHashMap, Value>,
+ fn update(
+ map: &mut BTreeMap, Value>,
lock: &mut RwSpinLock,
key: Key,
bytes: Direction,
diff --git a/windows_kext/driver/src/connection_cache.rs b/windows_kext/driver/src/connection_cache.rs
index 665e60f41..a076df022 100644
--- a/windows_kext/driver/src/connection_cache.rs
+++ b/windows_kext/driver/src/connection_cache.rs
@@ -1,10 +1,8 @@
-use core::time::Duration;
-
use crate::{
connection::{Connection, ConnectionV4, ConnectionV6, RedirectInfo, Verdict},
connection_map::{ConnectionMap, Key},
};
-use alloc::{format, string::String, vec::Vec};
+use alloc::vec::Vec;
use smoltcp::wire::IpProtocol;
use wdk::rw_spin_lock::RwSpinLock;
@@ -128,73 +126,4 @@ impl ConnectionCache {
return size;
}
-
- #[allow(dead_code)]
- pub fn get_full_cache_info(&self) -> String {
- let mut info = String::new();
- let now = wdk::utils::get_system_timestamp_ms();
- {
- let _guard = self.lock_v4.read_lock();
- for ((protocol, port), connections) in self.connections_v4.iter() {
- info.push_str(&format!("{} -> {}\n", protocol, port,));
- for conn in connections {
- let active_time_seconds =
- Duration::from_millis(now - conn.get_last_accessed_time()).as_secs();
- info.push_str(&format!(
- "\t{}:{} -> {}:{} {} last active {}m {}s ago",
- conn.local_address,
- conn.local_port,
- conn.remote_address,
- conn.remote_port,
- conn.verdict,
- active_time_seconds / 60,
- active_time_seconds % 60
- ));
- if conn.has_ended() {
- let end_time_seconds =
- Duration::from_millis(now - conn.get_end_time()).as_secs();
- info.push_str(&format!(
- "\t ended {}m {}s ago",
- end_time_seconds / 60,
- end_time_seconds % 60
- ));
- }
- info.push('\n');
- }
- }
- }
-
- {
- let _guard = self.lock_v6.read_lock();
- for ((protocol, port), connections) in self.connections_v6.iter() {
- info.push_str(&format!("{} -> {} \n", protocol, port));
- for conn in connections {
- let active_time_seconds =
- Duration::from_millis(now - conn.get_last_accessed_time()).as_secs();
- info.push_str(&format!(
- "\t{}:{} -> {}:{} {} last active {}m {}s ago",
- conn.local_address,
- conn.local_port,
- conn.remote_address,
- conn.remote_port,
- conn.verdict,
- active_time_seconds / 60,
- active_time_seconds % 60
- ));
- if conn.has_ended() {
- let end_time_seconds =
- Duration::from_millis(now - conn.get_end_time()).as_secs();
- info.push_str(&format!(
- "\t ended {}m {}s ago",
- end_time_seconds / 60,
- end_time_seconds % 60
- ));
- }
- info.push('\n');
- }
- }
- }
-
- return info;
- }
}
diff --git a/windows_kext/driver/src/connection_map.rs b/windows_kext/driver/src/connection_map.rs
index bf2210f8b..7cadf87be 100644
--- a/windows_kext/driver/src/connection_map.rs
+++ b/windows_kext/driver/src/connection_map.rs
@@ -1,8 +1,7 @@
use core::{fmt::Display, time::Duration};
use crate::connection::Connection;
-use alloc::vec::Vec;
-use hashbrown::HashMap;
+use alloc::{collections::BTreeMap, vec::Vec};
use smoltcp::wire::{IpAddress, IpProtocol};
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
@@ -63,11 +62,11 @@ impl Key {
}
}
-pub struct ConnectionMap(HashMap<(IpProtocol, u16), Vec>);
+pub struct ConnectionMap(BTreeMap<(IpProtocol, u16), Vec>);
impl ConnectionMap {
pub fn new() -> Self {
- Self(HashMap::new())
+ Self(BTreeMap::new())
}
pub fn add(&mut self, conn: T) {
@@ -164,7 +163,6 @@ impl ConnectionMap {
self.0.retain(|_, v| !v.is_empty());
}
- #[allow(dead_code)]
pub fn get_count(&self) -> usize {
let mut count = 0;
for conn in self.0.values() {
@@ -172,8 +170,4 @@ impl ConnectionMap {
}
return count;
}
-
- pub fn iter(&self) -> hashbrown::hash_map::Iter<'_, (IpProtocol, u16), Vec> {
- self.0.iter()
- }
}
diff --git a/windows_kext/driver/src/driver_hashmap.rs b/windows_kext/driver/src/driver_hashmap.rs
deleted file mode 100644
index 1c8b706ab..000000000
--- a/windows_kext/driver/src/driver_hashmap.rs
+++ /dev/null
@@ -1,25 +0,0 @@
-use core::ops::{Deref, DerefMut};
-
-use hashbrown::HashMap;
-
-pub struct DeviceHashMap(Option>);
-
-impl DeviceHashMap {
- pub fn new() -> Self {
- Self(Some(HashMap::new()))
- }
-}
-
-impl Deref for DeviceHashMap {
- type Target = HashMap;
-
- fn deref(&self) -> &Self::Target {
- self.0.as_ref().unwrap()
- }
-}
-
-impl DerefMut for DeviceHashMap {
- fn deref_mut(&mut self) -> &mut Self::Target {
- self.0.as_mut().unwrap()
- }
-}
diff --git a/windows_kext/driver/src/lib.rs b/windows_kext/driver/src/lib.rs
index d13e9d3f3..7d9fe3a1d 100644
--- a/windows_kext/driver/src/lib.rs
+++ b/windows_kext/driver/src/lib.rs
@@ -13,7 +13,6 @@ mod connection;
mod connection_cache;
mod connection_map;
mod device;
-mod driver_hashmap;
mod entry;
mod id_cache;
pub mod logger;
diff --git a/windows_kext/driver/src/stream_callouts.rs b/windows_kext/driver/src/stream_callouts.rs
index a63937643..f0a0f0d04 100644
--- a/windows_kext/driver/src/stream_callouts.rs
+++ b/windows_kext/driver/src/stream_callouts.rs
@@ -4,6 +4,8 @@ use wdk::filter_engine::{callout_data::CalloutData, layer, net_buffer::NetBuffer
use crate::{bandwidth, connection::Direction};
pub fn stream_layer_tcp_v4(data: CalloutData) {
+ type Fields = layer::FieldsStreamV4;
+
let Some(device) = crate::entry::get_device() else {
return;
};
@@ -16,7 +18,6 @@ pub fn stream_layer_tcp_v4(data: CalloutData) {
} else {
return;
};
- type Fields = layer::FieldsStreamV4;
let local_ip = Ipv4Address::from_bytes(
&data
.get_value_u32(Fields::IpLocalAddress as usize)
@@ -56,6 +57,8 @@ pub fn stream_layer_tcp_v4(data: CalloutData) {
}
pub fn stream_layer_tcp_v6(data: CalloutData) {
+ type Fields = layer::FieldsStreamV6;
+
let Some(device) = crate::entry::get_device() else {
return;
};
@@ -68,16 +71,18 @@ pub fn stream_layer_tcp_v6(data: CalloutData) {
} else {
return;
};
- type Fields = layer::FieldsStreamV6;
+
if data_length == 0 {
return;
}
let local_ip =
Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpLocalAddress as usize));
let local_port = data.get_value_u16(Fields::IpLocalPort as usize);
+
let remote_ip =
Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpRemoteAddress as usize));
let remote_port = data.get_value_u16(Fields::IpRemotePort as usize);
+
match direction {
Direction::Outbound => {
device.bandwidth_stats.update_tcp_v6_tx(
@@ -105,6 +110,8 @@ pub fn stream_layer_tcp_v6(data: CalloutData) {
}
pub fn stream_layer_udp_v4(data: CalloutData) {
+ type Fields = layer::FieldsDatagramDataV4;
+
let Some(device) = crate::entry::get_device() else {
return;
};
@@ -112,7 +119,6 @@ pub fn stream_layer_udp_v4(data: CalloutData) {
for nbl in NetBufferListIter::new(data.get_layer_data() as _) {
data_length += nbl.get_data_length() as usize;
}
- type Fields = layer::FieldsDatagramDataV4;
let mut direction = Direction::Inbound;
if data.get_value_u8(Fields::Direction as usize) == 0 {
direction = Direction::Outbound;
@@ -157,6 +163,8 @@ pub fn stream_layer_udp_v4(data: CalloutData) {
}
pub fn stream_layer_udp_v6(data: CalloutData) {
+ type Fields = layer::FieldsDatagramDataV6;
+
let Some(device) = crate::entry::get_device() else {
return;
};
@@ -164,7 +172,6 @@ pub fn stream_layer_udp_v6(data: CalloutData) {
for nbl in NetBufferListIter::new(data.get_layer_data() as _) {
data_length += nbl.get_data_length() as usize;
}
- type Fields = layer::FieldsDatagramDataV6;
let mut direction = Direction::Inbound;
if data.get_value_u8(Fields::Direction as usize) == 0 {
direction = Direction::Outbound;
diff --git a/windows_kext/kextinterface/kext.go b/windows_kext/kextinterface/kext.go
index 8322ead8f..c9b61c7c8 100644
--- a/windows_kext/kextinterface/kext.go
+++ b/windows_kext/kextinterface/kext.go
@@ -38,7 +38,7 @@ var (
)
const (
- winInvalidHandleValue = windows.Handle(^uintptr(0)) // Max value
+ winInvalidHandleValue = windows.InvalidHandle
stopServiceTimeoutDuration = time.Duration(30 * time.Second)
)
@@ -48,7 +48,7 @@ type KextService struct {
}
func (s *KextService) isValid() bool {
- return s != nil && s.handle != winInvalidHandleValue && s.handle != 0
+ return s != nil && s.handle != windows.InvalidHandle && s.handle != 0
}
func (s *KextService) isRunning() (bool, error) {
@@ -99,7 +99,7 @@ func (s *KextService) Start(wait bool) error {
_ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
_ = windows.DeleteService(s.handle)
_ = windows.CloseServiceHandle(s.handle)
- s.handle = winInvalidHandleValue
+ s.handle = windows.InvalidHandle
return err
}
}
@@ -158,7 +158,7 @@ func (s *KextService) Delete() error {
return fmt.Errorf("failed to close service handle: %s", err)
}
- s.handle = winInvalidHandleValue
+ s.handle = windows.InvalidHandle
return nil
}
@@ -234,7 +234,7 @@ func CreateKextService(driverName string, driverPath string) (*KextService, erro
return nil, err
}
- service = winInvalidHandleValue
+ service = windows.InvalidHandle
log.Warning("kext: old driver service was deleted successfully")
}
diff --git a/windows_kext/kextinterface/kext_file.go b/windows_kext/kextinterface/kext_file.go
index 045ee06e5..fac6c5cdd 100644
--- a/windows_kext/kextinterface/kext_file.go
+++ b/windows_kext/kextinterface/kext_file.go
@@ -4,6 +4,8 @@
package kextinterface
import (
+ "fmt"
+
"golang.org/x/sys/windows"
)
@@ -13,7 +15,16 @@ type KextFile struct {
read_slice []byte
}
+// Read tries to read the supplied buffer length from the driver.
+// The data from the driver is read in chunks `len(f.buffer)` and the extra data is cached for the next call.
+// The performance penalty of calling the function with small buffers is very small.
+// The function will block until the next info packet is received from the kext.
func (f *KextFile) Read(buffer []byte) (int, error) {
+ if err := f.IsValid(); err != nil {
+ return 0, fmt.Errorf("failed to read: %w", err)
+ }
+
+ // If no data is available from previous calls, read from kext.
if f.read_slice == nil || len(f.read_slice) == 0 {
err := f.refill_read_buffer()
if err != nil {
@@ -22,14 +33,19 @@ func (f *KextFile) Read(buffer []byte) (int, error) {
}
if len(f.read_slice) >= len(buffer) {
- // Write all requested bytes.
+ // There is enough data to fill the requested buffer.
copy(buffer, f.read_slice[0:len(buffer)])
+ // Move the slice to contain the remaining data.
f.read_slice = f.read_slice[len(buffer):]
} else {
- // Write all available bytes and read again.
+ // There is not enough data to fill the requested buffer.
+
+ // Write everything available.
copy(buffer[0:len(f.read_slice)], f.read_slice)
copiedBytes := len(f.read_slice)
f.read_slice = nil
+
+ // Read again.
_, err := f.Read(buffer[copiedBytes:])
if err != nil {
return 0, err
@@ -51,20 +67,33 @@ func (f *KextFile) refill_read_buffer() error {
return nil
}
+// Write sends the buffer bytes to the kext. The function will block until the whole buffer is written to the kext.
func (f *KextFile) Write(buffer []byte) (int, error) {
+ if err := f.IsValid(); err != nil {
+ return 0, fmt.Errorf("failed to write: %w", err)
+ }
var count uint32 = 0
overlapped := &windows.Overlapped{}
err := windows.WriteFile(f.handle, buffer, &count, overlapped)
return int(count), err
}
+// Close closes the handle to the kext. This will cancel all active Reads and Writes.
func (f *KextFile) Close() error {
+ if err := f.IsValid(); err != nil {
+ return fmt.Errorf("failed to close: %w", err)
+ }
err := windows.CloseHandle(f.handle)
- f.handle = winInvalidHandleValue
+ f.handle = windows.InvalidHandle
return err
}
+// deviceIOControl exists for compatibility with the old kext.
func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) (*windows.Overlapped, error) {
+ if err := f.IsValid(); err != nil {
+ return nil, fmt.Errorf("failed to send io control: %w", err)
+ }
+ // Prepare the input data
var inDataPtr *byte = nil
var inDataSize uint32 = 0
if inData != nil {
@@ -72,6 +101,7 @@ func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) (
inDataSize = uint32(len(inData))
}
+ // Prepare the output data
var outDataPtr *byte = nil
var outDataSize uint32 = 0
if outData != nil {
@@ -79,6 +109,7 @@ func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) (
outDataSize = uint32(len(outData))
}
+ // Make the request to the kext.
overlapped := &windows.Overlapped{}
err := windows.DeviceIoControl(f.handle,
code,
@@ -92,6 +123,20 @@ func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) (
return overlapped, nil
}
+// GetHandle returns the handle of the kext.
func (f *KextFile) GetHandle() windows.Handle {
return f.handle
}
+
+// IsValid checks if kext file holds a valid handle to the kext driver.
+func (f *KextFile) IsValid() error {
+ if f == nil {
+ return fmt.Errorf("nil kext file")
+ }
+
+ if f.handle == windows.Handle(0) || f.handle == windows.InvalidHandle {
+ return fmt.Errorf("invalid handle")
+ }
+
+ return nil
+}