diff --git a/README.md b/README.md index 7bc4273..d57a432 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,9 @@ Sample: `curl http://localhost:5050/log` --- -#### GET /mark-benign/{anomaly_uuid} +#### GET /mark-benign/{anomaly_uuid}|all -Description: Mark anomaly with provided UUID as benign. +Description: Mark anomaly with provided UUID as benign. One can also use `all` in place of a UUID to clear all current in-memory anomalies. Sample: `curl http://localhost:5050/mark-benign/00000000-1234-1234-1234-123456789012` diff --git a/aud_manager/aud.py b/aud_manager/aud.py index 155b520..adf0a25 100644 --- a/aud_manager/aud.py +++ b/aud_manager/aud.py @@ -25,6 +25,12 @@ class ACLKey(NamedTuple): addr: str svc_port: int +class FreqKey(NamedTuple): + ip_ver: int + direction: str + proto: int + svc_port: int + class Direction(Enum): FWD = 0 REV = 1 @@ -42,7 +48,7 @@ class Severity(Enum): Alarming = 4 class Anomaly(): - def __init__(self, category=Category.Undefined, conn=None): + def __init__(self, category=Category.Undefined, conn=None, score=0.0): self.time = datetime.now(timezone.utc).replace(microsecond=0) self.uuid = uuid.uuid4() self.conn = conn @@ -55,25 +61,41 @@ def __init__(self, category=Category.Undefined, conn=None): self.category = category self.severity = Severity.Unknown - self.score = 0.0 + self.score = score self.post_to_dht() def as_dict(self): acl_key = self.conn.get_acl_key() + if acl_key.proto == 1: + svc_port = None + else: + svc_port = acl_key.svc_port + + if self.category == Category.FrequentFlow: + details = { + "direction": str(acl_key.direction), + "proto": l4proto[acl_key.proto], + "svc_port": str(svc_port), + "ip_ver": acl_key.ip_ver + } + else: + details = { + "direction": str(acl_key.direction), + "proto": l4proto[acl_key.proto], + "svc_port": str(svc_port), + "addr": str(acl_key.addr), + "ip_ver": acl_key.ip_ver, + } + return { "anomaly_uuid": str(self.uuid), "time": str(self.time), "category": str(self.category.name), "severity": str(self.severity.name), "score": str(round(self.score, 3)), - "details": { - "direction": str(acl_key.direction), - "proto": l4proto[acl_key.proto]+":"+str(acl_key.svc_port), - "addr": str(acl_key.addr), - "ip_ver": acl_key.ip_ver, - } + "details": details } def post_to_dht(self): @@ -220,6 +242,33 @@ def stats_update(self): def pep_distribution(self): return Counter(self.peps) +class FrequencyCounter: + def __init__(self, ws, thresh): + self.winsize = ws * 1000000000 + self.threshold = thresh + self.counters = dict() + self.connref = dict() + + def __str__(self): + return str(self.counters) + + def add(self, conn): + key = conn.get_freq_key() + if key not in self.counters: + self.counters[key] = [] + self.connref[key] = conn + + self.counters[key].append(conn.created_ns) + + def evaluate(self): + now = time.time_ns() + + for counter, timestamps in self.counters.items(): + timestamps[:] = [ts for ts in timestamps if ts > (now - self.winsize)] + if len(timestamps) > self.threshold: + ratio = round((len(timestamps) / self.threshold), 3) + yield Anomaly(category=Category.FrequentFlow, conn=self.connref[counter], score=ratio) + class AUDRecord: def __init__(self, aud_handle): @@ -239,11 +288,13 @@ def as_dict(self): def process(self, connlist): for conn in connlist: if self.remote_as is None: - self.aud.anomalies.append(Anomaly(category=Category.NovelFlow, conn=conn)) + #self.aud.anomalies.append(Anomaly(category=Category.NovelFlow, conn=conn)) ### TODO: Resolve remote AS based on acl_key.addr self.remote_as = "Unresolved/FIXTHIS" - # AS score and evaluation TODO + if conn.new: + self.aud.freq_counter.add(conn) + conn.new = False if conn.active(): # Do not aggregate stats over partial flow records @@ -259,8 +310,8 @@ def process(self, connlist): conn.marked_for_deletion = True - def evaluate(self, category, conn): - logging.debug("evaluate") + def evaluate(self): + pass class AUD: @@ -268,6 +319,7 @@ def __init__(self): self.global_conn_counter = 0 self.last_updated = 0 self.records = dict() + self.freq_counter = FrequencyCounter(30, 30) self.anomalies = deque(maxlen=100) @@ -290,7 +342,22 @@ def update(self, connlist): self.records[key].process(connlist.conns_by_acl_key(key)) + def evaluate(self): + count = 0 + for record in self.records.values(): + record.evaluate() + + for result in self.freq_counter.evaluate(): + self.anomalies.append(result) + count += 1 + + return count + def mark_benign(self, input_uuid_string): + if input_uuid_string == "all": + self.anomalies.clear() + return "OK" + try: needle = uuid.UUID(input_uuid_string) except ValueError as ve: diff --git a/aud_manager/aud_conn.py b/aud_manager/aud_conn.py index f01f987..4559838 100644 --- a/aud_manager/aud_conn.py +++ b/aud_manager/aud_conn.py @@ -27,8 +27,6 @@ def __init__(self, aud_handle): self.lookup = dict() self.conns = list() - self.timeout = 60*1000000 - def __len__(self): return len(self.conns) @@ -99,6 +97,7 @@ def aggregate_acl_keys(self): class ConnEntry(): def __init__(self, key, l3hdr, l4hdr): self.key = key + self.new = True if l3hdr.direction == pr.socket.PACKET_HOST: self.acl_direction = "inbound" # to @@ -150,6 +149,12 @@ def get_acl_key(self): addr = self.acl_addr, svc_port = self.key.dst_port) + def get_freq_key(self): + return aud.FreqKey(ip_ver = self.local_ip.version, + direction = self.acl_direction, + proto = self.key.proto, + svc_port = self.key.dst_port) + def append(self, direction, t, plen, flags): self.data.add(t, plen, direction) self.last_updated = t diff --git a/aud_manager/aud_manager.py b/aud_manager/aud_manager.py index a7119ed..9151789 100644 --- a/aud_manager/aud_manager.py +++ b/aud_manager/aud_manager.py @@ -89,6 +89,7 @@ def run(self): if aud_update_t < time.time(): self.aud_update() + self.aud_evaluate() self.connlist.trim() aud_update_t = time.time() + self.aud_update_interval @@ -115,6 +116,11 @@ def aud_update(self): self.aud.update(self.connlist) logging.debug("aud_update() finished in %f seconds.", round((time.time() - start_t), 3)) + def aud_evaluate(self): + start_t = time.time() + res = self.aud.evaluate() + logging.debug("aud_evaluate() finished in %f seconds. %d anomalies reported", round((time.time() - start_t), 3), res) + def response(self, res): return json.dumps({"response": str(res)})