From 49adecb25ce4f456f15e47cb2f49ea611f26957d Mon Sep 17 00:00:00 2001 From: Yacine Elhamer Date: Sat, 26 Aug 2023 18:11:35 +0200 Subject: [PATCH] add yaml representer for the Scope class, as well as other bugfixes --- capa/features/extractors/null.py | 6 ++++++ capa/rules/__init__.py | 5 +++++ tests/test_result_document.py | 1 - 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/capa/features/extractors/null.py b/capa/features/extractors/null.py index 50bd85114..f6797aa96 100644 --- a/capa/features/extractors/null.py +++ b/capa/features/extractors/null.py @@ -59,6 +59,9 @@ class NullStaticFeatureExtractor(StaticFeatureExtractor): def get_base_address(self): return self.base_address + def get_sample_hashes(self) -> SampleHashes: + return self.sample_hashes + def extract_global_features(self): for feature in self.global_features: yield feature, NO_ADDRESS @@ -121,6 +124,9 @@ def extract_global_features(self): for feature in self.global_features: yield feature, NO_ADDRESS + def get_sample_hashes(self) -> SampleHashes: + return self.sample_hashes + def extract_file_features(self): for address, feature in self.file_features: yield feature, address diff --git a/capa/rules/__init__.py b/capa/rules/__init__.py index 04ea11bd2..35f2a0907 100644 --- a/capa/rules/__init__.py +++ b/capa/rules/__init__.py @@ -86,6 +86,10 @@ class Scope(str, Enum): # not used to validate rules. GLOBAL = "global" + @classmethod + def to_yaml(cls, representer, node): + return representer.represent_str(f"{node.value}") + # these literals are used to check if the flavor # of a rule is correct. @@ -979,6 +983,7 @@ def _get_ruamel_yaml_parser(): # we use the ruamel.yaml parser because it supports roundtripping of documents with comments. y = ruamel.yaml.YAML(typ="rt") + y.register_class(Scope) # use block mode, not inline json-like mode y.default_flow_style = False diff --git a/tests/test_result_document.py b/tests/test_result_document.py index 0311a1d69..10f022d94 100644 --- a/tests/test_result_document.py +++ b/tests/test_result_document.py @@ -263,7 +263,6 @@ def assert_round_trip(rd: rdoc.ResultDocument): pytest.param("a076114_rd"), pytest.param("pma0101_rd"), pytest.param("dotnet_1c444e_rd"), - pytest.param(""), ], ) def test_round_trip(request, rd_file):