Skip to content

Commit

Permalink
[service] Fix check for invalid kext handle (#1716)
Browse files Browse the repository at this point in the history
* [service] Fix check for invalid kext handle

* [windows_kext] Use BTreeMap as cache structure

* [windows_kext] Fix synchronization bug

* Update windows_kext/kextinterface/kext_file.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update windows_kext/kextinterface/kext_file.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update windows_kext/kextinterface/kext_file.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
vlabo and coderabbitai[bot] authored Oct 16, 2024
1 parent cfd8777 commit 355f743
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 196 deletions.
54 changes: 0 additions & 54 deletions windows_kext/driver/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion windows_kext/driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ num = { version = "0.4", default-features = false }
num-derive = { version = "0.4", default-features = false }
num-traits = { version = "0.2", default-features = false }
smoltcp = { version = "0.10", default-features = false, features = ["proto-ipv4", "proto-ipv6"] }
hashbrown = { version = "0.14.3", default-features = false, features = ["ahash"]}

# WARNING: Do not update. The version was choosen for a reason. See wdk/README.md for more detiels.
[dependencies.windows-sys]
Expand Down
40 changes: 18 additions & 22 deletions windows_kext/driver/src/bandwidth.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use alloc::collections::BTreeMap;
use protocol::info::{BandwidthValueV4, BandwidthValueV6, Info};
use smoltcp::wire::{IpProtocol, Ipv4Address, Ipv6Address};
use wdk::rw_spin_lock::RwSpinLock;

use crate::driver_hashmap::DeviceHashMap;

#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
pub struct Key<Address>
where
Address: Eq + PartialEq,
{
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
pub struct Key<Address: Ord> {
pub local_ip: Address,
pub local_port: u16,
pub remote_ip: Address,
Expand All @@ -25,32 +21,32 @@ enum Direction {
Rx(usize),
}
pub struct Bandwidth {
stats_tcp_v4: DeviceHashMap<Key<Ipv4Address>, Value>,
stats_tcp_v4: BTreeMap<Key<Ipv4Address>, Value>,
stats_tcp_v4_lock: RwSpinLock,

stats_tcp_v6: DeviceHashMap<Key<Ipv6Address>, Value>,
stats_tcp_v6: BTreeMap<Key<Ipv6Address>, Value>,
stats_tcp_v6_lock: RwSpinLock,

stats_udp_v4: DeviceHashMap<Key<Ipv4Address>, Value>,
stats_udp_v4: BTreeMap<Key<Ipv4Address>, Value>,
stats_udp_v4_lock: RwSpinLock,

stats_udp_v6: DeviceHashMap<Key<Ipv6Address>, Value>,
stats_udp_v6: BTreeMap<Key<Ipv6Address>, Value>,
stats_udp_v6_lock: RwSpinLock,
}

impl Bandwidth {
pub fn new() -> Self {
Self {
stats_tcp_v4: DeviceHashMap::new(),
stats_tcp_v4: BTreeMap::new(),
stats_tcp_v4_lock: RwSpinLock::default(),

stats_tcp_v6: DeviceHashMap::new(),
stats_tcp_v6: BTreeMap::new(),
stats_tcp_v6_lock: RwSpinLock::default(),

stats_udp_v4: DeviceHashMap::new(),
stats_udp_v4: BTreeMap::new(),
stats_udp_v4_lock: RwSpinLock::default(),

stats_udp_v6: DeviceHashMap::new(),
stats_udp_v6: BTreeMap::new(),
stats_udp_v6_lock: RwSpinLock::default(),
}
}
Expand All @@ -62,7 +58,7 @@ impl Bandwidth {
if self.stats_tcp_v4.is_empty() {
return None;
}
stats_map = core::mem::replace(&mut self.stats_tcp_v4, DeviceHashMap::new());
stats_map = core::mem::replace(&mut self.stats_tcp_v4, BTreeMap::new());
}

let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
Expand All @@ -89,7 +85,7 @@ impl Bandwidth {
if self.stats_tcp_v6.is_empty() {
return None;
}
stats_map = core::mem::replace(&mut self.stats_tcp_v6, DeviceHashMap::new());
stats_map = core::mem::replace(&mut self.stats_tcp_v6, BTreeMap::new());
}

let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
Expand All @@ -116,7 +112,7 @@ impl Bandwidth {
if self.stats_udp_v4.is_empty() {
return None;
}
stats_map = core::mem::replace(&mut self.stats_udp_v4, DeviceHashMap::new());
stats_map = core::mem::replace(&mut self.stats_udp_v4, BTreeMap::new());
}

let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
Expand All @@ -140,10 +136,10 @@ impl Bandwidth {
let stats_map;
{
let _guard = self.stats_udp_v6_lock.write_lock();
if self.stats_tcp_v6.is_empty() {
if self.stats_udp_v6.is_empty() {
return None;
}
stats_map = core::mem::replace(&mut self.stats_tcp_v6, DeviceHashMap::new());
stats_map = core::mem::replace(&mut self.stats_udp_v6, BTreeMap::new());
}

let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
Expand Down Expand Up @@ -235,8 +231,8 @@ impl Bandwidth {
);
}

fn update<Address: Eq + PartialEq + core::hash::Hash>(
map: &mut DeviceHashMap<Key<Address>, Value>,
fn update<Address: Ord>(
map: &mut BTreeMap<Key<Address>, Value>,
lock: &mut RwSpinLock,
key: Key<Address>,
bytes: Direction,
Expand Down
73 changes: 1 addition & 72 deletions windows_kext/driver/src/connection_cache.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use core::time::Duration;

use crate::{
connection::{Connection, ConnectionV4, ConnectionV6, RedirectInfo, Verdict},
connection_map::{ConnectionMap, Key},
};
use alloc::{format, string::String, vec::Vec};
use alloc::vec::Vec;

use smoltcp::wire::IpProtocol;
use wdk::rw_spin_lock::RwSpinLock;
Expand Down Expand Up @@ -128,73 +126,4 @@ impl ConnectionCache {

return size;
}

#[allow(dead_code)]
pub fn get_full_cache_info(&self) -> String {
let mut info = String::new();
let now = wdk::utils::get_system_timestamp_ms();
{
let _guard = self.lock_v4.read_lock();
for ((protocol, port), connections) in self.connections_v4.iter() {
info.push_str(&format!("{} -> {}\n", protocol, port,));
for conn in connections {
let active_time_seconds =
Duration::from_millis(now - conn.get_last_accessed_time()).as_secs();
info.push_str(&format!(
"\t{}:{} -> {}:{} {} last active {}m {}s ago",
conn.local_address,
conn.local_port,
conn.remote_address,
conn.remote_port,
conn.verdict,
active_time_seconds / 60,
active_time_seconds % 60
));
if conn.has_ended() {
let end_time_seconds =
Duration::from_millis(now - conn.get_end_time()).as_secs();
info.push_str(&format!(
"\t ended {}m {}s ago",
end_time_seconds / 60,
end_time_seconds % 60
));
}
info.push('\n');
}
}
}

{
let _guard = self.lock_v6.read_lock();
for ((protocol, port), connections) in self.connections_v6.iter() {
info.push_str(&format!("{} -> {} \n", protocol, port));
for conn in connections {
let active_time_seconds =
Duration::from_millis(now - conn.get_last_accessed_time()).as_secs();
info.push_str(&format!(
"\t{}:{} -> {}:{} {} last active {}m {}s ago",
conn.local_address,
conn.local_port,
conn.remote_address,
conn.remote_port,
conn.verdict,
active_time_seconds / 60,
active_time_seconds % 60
));
if conn.has_ended() {
let end_time_seconds =
Duration::from_millis(now - conn.get_end_time()).as_secs();
info.push_str(&format!(
"\t ended {}m {}s ago",
end_time_seconds / 60,
end_time_seconds % 60
));
}
info.push('\n');
}
}
}

return info;
}
}
12 changes: 3 additions & 9 deletions windows_kext/driver/src/connection_map.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use core::{fmt::Display, time::Duration};

use crate::connection::Connection;
use alloc::vec::Vec;
use hashbrown::HashMap;
use alloc::{collections::BTreeMap, vec::Vec};
use smoltcp::wire::{IpAddress, IpProtocol};

#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
Expand Down Expand Up @@ -63,11 +62,11 @@ impl Key {
}
}

pub struct ConnectionMap<T: Connection>(HashMap<(IpProtocol, u16), Vec<T>>);
pub struct ConnectionMap<T: Connection>(BTreeMap<(IpProtocol, u16), Vec<T>>);

impl<T: Connection + Clone> ConnectionMap<T> {
pub fn new() -> Self {
Self(HashMap::new())
Self(BTreeMap::new())
}

pub fn add(&mut self, conn: T) {
Expand Down Expand Up @@ -164,16 +163,11 @@ impl<T: Connection + Clone> ConnectionMap<T> {
self.0.retain(|_, v| !v.is_empty());
}

#[allow(dead_code)]
pub fn get_count(&self) -> usize {
let mut count = 0;
for conn in self.0.values() {
count += conn.len();
}
return count;
}

pub fn iter(&self) -> hashbrown::hash_map::Iter<'_, (IpProtocol, u16), Vec<T>> {
self.0.iter()
}
}
25 changes: 0 additions & 25 deletions windows_kext/driver/src/driver_hashmap.rs

This file was deleted.

1 change: 0 additions & 1 deletion windows_kext/driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ mod connection;
mod connection_cache;
mod connection_map;
mod device;
mod driver_hashmap;
mod entry;
mod id_cache;
pub mod logger;
Expand Down
Loading

0 comments on commit 355f743

Please sign in to comment.