diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bb04eb2e92..ff364d3094 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,9 @@ in development Python 3.6 is no longer supported; Stackstorm requires at least Python 3.8. +* Implemented zstandard compression for parameters and results. #5995 + Contributed by @guzzijones12 + Fixed ~~~~~ * Restore Pack integration testing (it was inadvertently skipped) and stop testing against `bionic` and `el7`. #6135 diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample index 82da2ae2e1..d812580046 100644 --- a/conf/st2.conf.sample +++ b/conf/st2.conf.sample @@ -138,6 +138,9 @@ connection_timeout = 3000 db_name = st2 # host of db server host = 127.0.0.1 +# compression for parameter and result storage in liveaction and execution models +# Valid values: zstandard, none +parameter_result_compression = zstandard # password for db login password = None # port of db server diff --git a/conf/st2.dev.conf b/conf/st2.dev.conf index cf2b5b6596..c2276ec332 100644 --- a/conf/st2.dev.conf +++ b/conf/st2.dev.conf @@ -1,6 +1,7 @@ # Config used by local development environment (tools/launch.dev.sh) [database] host = 127.0.0.1 +parameter_result_compression = zstandard [api] # Host and port to bind the API server. 
diff --git a/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py b/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py index 586a9d0cc9..4e5db3dfe8 100644 --- a/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py +++ b/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py @@ -136,12 +136,22 @@ def start_workflow(self, action_parameters): wf_def, self.execution, st2_ctx, notify_cfg=notify_cfg ) except wf_exc.WorkflowInspectionError as e: + _, ex, tb = sys.exc_info() status = ac_const.LIVEACTION_STATUS_FAILED - result = {"errors": e.args[1], "output": None} + result = { + "errors": e.args[1], + "output": None, + "traceback": "".join(traceback.format_tb(tb, 20)), + } return (status, result, self.context) except Exception as e: + _, ex, tb = sys.exc_info() status = ac_const.LIVEACTION_STATUS_FAILED - result = {"errors": [{"message": six.text_type(e)}], "output": None} + result = { + "errors": [{"message": six.text_type(e)}], + "output": None, + "traceback": "".join(traceback.format_tb(tb, 20)), + } return (status, result, self.context) return self._handle_workflow_return_value(wf_ex_db) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py b/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py index 9a4dd1cd5b..a11d2eed11 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py @@ -363,11 +363,21 @@ def test_fail_start_task_input_value_type(self): workflow_execution=str(wf_ex_db.id) )[0] self.assertEqual(tk_ex_db.status, wf_statuses.FAILED) - self.assertDictEqual(tk_ex_db.result, {"errors": expected_errors}) + self.assertEqual( + tk_ex_db.result["errors"][0]["type"], expected_errors[0]["type"] + ) + self.assertEqual( + tk_ex_db.result["errors"][0]["message"], expected_errors[0]["message"] + ) + self.assertEqual( + tk_ex_db.result["errors"][0]["task_id"], 
expected_errors[0]["task_id"] + ) + self.assertEqual( + tk_ex_db.result["errors"][0]["route"], expected_errors[0]["route"] + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(lv_ac_db.result, expected_result) ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db.id)) self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -522,13 +532,37 @@ def test_fail_next_task_input_value_type(self): # Assert workflow execution and task2 execution failed. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(wf_ex_db.id)) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual( - self.sort_workflow_errors(wf_ex_db.errors), expected_errors + self.assertEqual( + self.sort_workflow_errors(wf_ex_db.errors)[0]["type"], + expected_errors[0]["type"], + ) + self.assertEqual( + self.sort_workflow_errors(wf_ex_db.errors)[0]["message"], + expected_errors[0]["message"], + ) + self.assertEqual( + self.sort_workflow_errors(wf_ex_db.errors)[0]["task_id"], + expected_errors[0]["task_id"], + ) + self.assertEqual( + self.sort_workflow_errors(wf_ex_db.errors)[0]["route"], + expected_errors[0]["route"], ) tk2_ex_db = wf_db_access.TaskExecution.query(task_id="task2")[0] self.assertEqual(tk2_ex_db.status, wf_statuses.FAILED) - self.assertDictEqual(tk2_ex_db.result, {"errors": expected_errors}) + self.assertEqual( + tk2_ex_db.result["errors"][0]["type"], expected_errors[0]["type"] + ) + self.assertEqual( + tk2_ex_db.result["errors"][0]["message"], expected_errors[0]["message"] + ) + self.assertEqual( + tk2_ex_db.result["errors"][0]["task_id"], expected_errors[0]["task_id"] + ) + self.assertEqual( + tk2_ex_db.result["errors"][0]["route"], expected_errors[0]["route"] + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) diff --git 
a/contrib/runners/orquesta_runner/tests/unit/test_notify.py b/contrib/runners/orquesta_runner/tests/unit/test_notify.py index ff7114a318..2e7bfcad4f 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_notify.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_notify.py @@ -235,7 +235,11 @@ def test_notify_task_list_nonexistent_task(self): } self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(lv_ac_db.result, expected_result) + self.assertEqual( + lv_ac_db.result["errors"][0]["message"], + expected_result["errors"][0]["message"], + ) + self.assertIsNone(lv_ac_db.result["output"], expected_result["output"]) def test_notify_task_list_item_value(self): wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") diff --git a/st2actions/st2actions/container/base.py b/st2actions/st2actions/container/base.py index 71fe218292..6b441c0277 100644 --- a/st2actions/st2actions/container/base.py +++ b/st2actions/st2actions/container/base.py @@ -141,10 +141,13 @@ def _do_run(self, runner): ): queries.setup_query(runner.liveaction.id, runner.runner_type, context) except: - LOG.exception("Failed to run action.") _, ex, tb = sys.exc_info() # mark execution as failed. status = action_constants.LIVEACTION_STATUS_FAILED + LOG.exception( + "Failed to run action. traceback: %s" + % "".join(traceback.format_tb(tb, 20)) + ) # include the error message and traceback to try and provide some hints. 
result = { "error": str(ex), diff --git a/st2actions/st2actions/policies/concurrency_by_attr.py b/st2actions/st2actions/policies/concurrency_by_attr.py index 9f503bf18a..6ba064f7e5 100644 --- a/st2actions/st2actions/policies/concurrency_by_attr.py +++ b/st2actions/st2actions/policies/concurrency_by_attr.py @@ -15,10 +15,9 @@ from __future__ import absolute_import -import six - from st2common.constants import action as action_constants from st2common import log as logging +from st2common.fields import JSONDictEscapedFieldCompatibilityField from st2common.persistence import action as action_access from st2common.services import action as action_service from st2common.policies.concurrency import BaseConcurrencyApplicator @@ -41,31 +40,43 @@ def __init__( ) self.attributes = attributes or [] - def _get_filters(self, target): - filters = { - ("parameters__%s" % k): v - for k, v in six.iteritems(target.parameters) - if k in self.attributes - } - - filters["action"] = target.action - filters["status"] = None - - return filters - def _apply_before(self, target): - # Get the count of scheduled and running instances of the action. - filters = self._get_filters(target) - # Get the count of scheduled instances of the action. - filters["status"] = action_constants.LIVEACTION_STATUS_SCHEDULED - scheduled = action_access.LiveAction.count(**filters) + scheduled_filters = { + "status": action_constants.LIVEACTION_STATUS_SCHEDULED, + "action": target.action, + } + scheduled = [i for i in action_access.LiveAction.query(**scheduled_filters)] - # Get the count of running instances of the action. 
- filters["status"] = action_constants.LIVEACTION_STATUS_RUNNING - running = action_access.LiveAction.count(**filters) + running_filters = { + "status": action_constants.LIVEACTION_STATUS_RUNNING, + "action": target.action, + } + running = [i for i in action_access.LiveAction.query(**running_filters)] + running.extend(scheduled) + count = 0 + target_parameters = JSONDictEscapedFieldCompatibilityField().parse_field_value( + target.parameters + ) + target_key_value_policy_attributes = { + k: v for k, v in target_parameters.items() if k in self.attributes + } - count = scheduled + running + for i in running: + running_event_parameters = ( + JSONDictEscapedFieldCompatibilityField().parse_field_value(i.parameters) + ) + # list of event parameter values that are also in policy + running_event_policy_item_key_value_attributes = { + k: v + for k, v in running_event_parameters.items() + if k in self.attributes + } + if ( + running_event_policy_item_key_value_attributes + == target_key_value_policy_attributes + ): + count += 1 # Mark the execution as scheduled if threshold is not reached or delayed otherwise. 
if count < self.threshold: diff --git a/st2api/st2api/controllers/v1/actionexecutions.py b/st2api/st2api/controllers/v1/actionexecutions.py index 70d709192e..32533f2430 100644 --- a/st2api/st2api/controllers/v1/actionexecutions.py +++ b/st2api/st2api/controllers/v1/actionexecutions.py @@ -39,6 +39,7 @@ from st2common.exceptions import apivalidation as validation_exc from st2common.exceptions import param as param_exc from st2common.exceptions import trace as trace_exc +from st2common.fields import JSONDictEscapedFieldCompatibilityField from st2common.models.api.action import LiveActionAPI from st2common.models.api.action import LiveActionCreateAPI from st2common.models.api.base import cast_argument_value @@ -416,36 +417,18 @@ def get( :rtype: ``str`` """ - # NOTE: Here we intentionally use as_pymongo() to avoid mongoengine layer even for old style - # data + # NOTE: we need to use to_python() to uncompress the data try: result = ( - self.access.impl.model.objects.filter(id=id) - .only("result") - .as_pymongo()[0] + self.access.impl.model.objects.filter(id=id).only("result")[0].result ) except IndexError: raise NotFoundException("Execution with id %s not found" % (id)) - if isinstance(result["result"], dict): - # For backward compatibility we also support old non JSON field storage format - if pretty_format: - response_body = orjson.dumps( - result["result"], option=orjson.OPT_INDENT_2 - ) - else: - response_body = orjson.dumps(result["result"]) + if pretty_format: + response_body = orjson.dumps(result, option=orjson.OPT_INDENT_2) else: - # For new JSON storage format we just use raw value since it's already JSON serialized - # string - response_body = result["result"] - - if pretty_format: - # Pretty format is not a default behavior since it adds quite some overhead (e.g. 
- # 10-30ms for non pretty format for 4 MB json vs ~120 ms for pretty formatted) - response_body = orjson.dumps( - orjson.loads(result["result"]), option=orjson.OPT_INDENT_2 - ) + response_body = orjson.dumps(result) response = Response() response.headers["Content-Type"] = "text/json" @@ -634,8 +617,14 @@ def post(self, spec_api, id, requester_user, no_merge=False, show_secrets=False) # Merge in any parameters provided by the user new_parameters = {} + original_parameters = getattr(existing_execution, "parameters", b"{}") + original_params_decoded = ( + JSONDictEscapedFieldCompatibilityField().parse_field_value( + original_parameters + ) + ) if not no_merge: - new_parameters.update(getattr(existing_execution, "parameters", {})) + new_parameters.update(original_params_decoded) new_parameters.update(spec_api.parameters) # Create object for the new execution diff --git a/st2common/benchmarks/micro/test_mongo_field_types.py b/st2common/benchmarks/micro/test_mongo_field_types.py index 54e5ead509..e9b3077d3b 100644 --- a/st2common/benchmarks/micro/test_mongo_field_types.py +++ b/st2common/benchmarks/micro/test_mongo_field_types.py @@ -46,6 +46,7 @@ import pytest import mongoengine as me +import orjson from st2common.service_setup import db_setup from st2common.models.db import stormbase @@ -62,7 +63,50 @@ LiveActionDB._meta["allow_inheritance"] = True # pylint: disable=no-member -# 1. Current approach aka using EscapedDynamicField +class OldJSONDictField(JSONDictField): + def parse_field_value(self, value) -> dict: + """ + Parse provided binary field value and return parsed value (dictionary). + + For example: + + - (n, o, ...) - no compression, data is serialized using orjson + - (z, o, ...) 
- zstandard compression, data is serialized using orjson + """ + if not value: + return self.default + + if isinstance(value, dict): + # Already deserializaed + return value + + data = orjson.loads(value) + return data + + def _serialize_field_value(self, value: dict) -> bytes: + """ + Serialize and encode the provided field value. + """ + # Orquesta workflows support toSet() YAQL operator which returns a set which used to get + # serialized to list by mongoengine DictField. + # + # For backward compatibility reasons, we need to support serializing set to a list as + # well. + # + # Based on micro benchmarks, using default function adds very little overhead (1%) so it + # should be safe to use default for every operation. + # + # If this turns out to be not true or it adds more overhead in other scenarios, we should + # revisit this decision and only use "default" argument where needed (aka Workflow models). + def default(obj): + if isinstance(obj, set): + return list(obj) + raise TypeError + + return orjson.dumps(value, default=default) + + +# 1. old approach aka using EscapedDynamicField class LiveActionDB_EscapedDynamicField(LiveActionDB): result = stormbase.EscapedDynamicField(default={}) @@ -71,46 +115,31 @@ class LiveActionDB_EscapedDynamicField(LiveActionDB): field3 = stormbase.EscapedDynamicField(default={}) -# 2. Current approach aka using EscapedDictField +# 2. old approach aka using EscapedDictField class LiveActionDB_EscapedDictField(LiveActionDB): result = stormbase.EscapedDictField(default={}) - field1 = stormbase.EscapedDynamicField(default={}, use_header=False) - field2 = stormbase.EscapedDynamicField(default={}, use_header=False) - field3 = stormbase.EscapedDynamicField(default={}, use_header=False) - - -# 3. 
Approach which uses new JSONDictField where value is stored as serialized JSON string / blob -class LiveActionDB_JSONField(LiveActionDB): - result = JSONDictField(default={}, use_header=False) - - field1 = JSONDictField(default={}, use_header=False) - field2 = JSONDictField(default={}, use_header=False) - field3 = JSONDictField(default={}, use_header=False) + field1 = stormbase.EscapedDynamicField(default={}) + field2 = stormbase.EscapedDynamicField(default={}) + field3 = stormbase.EscapedDynamicField(default={}) -class LiveActionDB_JSONFieldWithHeader(LiveActionDB): - result = JSONDictField(default={}, use_header=True, compression_algorithm="none") +# 3. Old Approach which uses no compression where value is stored as serialized JSON string / blob +class LiveActionDB_OLDJSONField(LiveActionDB): + result = OldJSONDictField(default={}, use_header=False) - field1 = JSONDictField(default={}, use_header=True, compression_algorithm="none") - field2 = JSONDictField(default={}, use_header=True, compression_algorithm="none") - field3 = JSONDictField(default={}, use_header=True, compression_algorithm="none") + field1 = OldJSONDictField(default={}) + field2 = OldJSONDictField(default={}) + field3 = OldJSONDictField(default={}) -class LiveActionDB_JSONFieldWithHeaderAndZstandard(LiveActionDB): - result = JSONDictField( - default={}, use_header=True, compression_algorithm="zstandard" - ) +# 4. 
Current Approach which uses new JSONDictField where value is stored as zstandard compressed serialized JSON string / blob +class LiveActionDB_JSONField(LiveActionDB): + result = JSONDictField(default={}, use_header=False) - field1 = JSONDictField( - default={}, use_header=True, compression_algorithm="zstandard" - ) - field2 = JSONDictField( - default={}, use_header=True, compression_algorithm="zstandard" - ) - field3 = JSONDictField( - default={}, use_header=True, compression_algorithm="zstandard" - ) + field1 = JSONDictField(default={}) + field2 = JSONDictField(default={}) + field3 = JSONDictField(default={}) class LiveActionDB_StringField(LiveActionDB): @@ -128,10 +157,8 @@ def get_model_class_for_approach(approach: str) -> Type[LiveActionDB]: model_cls = LiveActionDB_EscapedDictField elif approach == "json_dict_field": model_cls = LiveActionDB_JSONField - elif approach == "json_dict_field_with_header": - model_cls = LiveActionDB_JSONFieldWithHeader - elif approach == "json_dict_field_with_header_and_zstd": - model_cls = LiveActionDB_JSONFieldWithHeaderAndZstandard + elif approach == "old_json_dict_field": + model_cls = LiveActionDB_OLDJSONField else: raise ValueError("Invalid approach: %s" % (approach)) @@ -142,18 +169,12 @@ def get_model_class_for_approach(approach: str) -> Type[LiveActionDB]: @pytest.mark.parametrize( "approach", [ - "escaped_dynamic_field", - "escaped_dict_field", + "old_json_dict_field", "json_dict_field", - "json_dict_field_with_header", - "json_dict_field_with_header_and_zstd", ], ids=[ - "escaped_dynamic_field", - "escaped_dict_field", + "old_json_dict_field", "json_dict_field", - "json_dict_field_w_header", - "json_dict_field_w_header_and_zstd", ], ) @pytest.mark.benchmark(group="live_action_save") @@ -187,18 +208,12 @@ def run_benchmark(): @pytest.mark.parametrize( "approach", [ - "escaped_dynamic_field", - "escaped_dict_field", + "old_json_dict_field", "json_dict_field", - "json_dict_field_with_header", - 
"json_dict_field_with_header_and_zstd", ], ids=[ - "escaped_dynamic_field", - "escaped_dict_field", + "old_json_dict_field", "json_dict_field", - "json_dict_field_w_header", - "json_dict_field_w_header_and_zstd", ], ) @pytest.mark.benchmark(group="live_action_save_multiple_fields") @@ -240,18 +255,12 @@ def run_benchmark(): @pytest.mark.parametrize( "approach", [ - "escaped_dynamic_field", - "escaped_dict_field", + "old_json_dict_field", "json_dict_field", - "json_dict_field_with_header", - "json_dict_field_with_header_and_zstd", ], ids=[ - "escaped_dynamic_field", - "escaped_dict_field", + "old_json_dict_field", "json_dict_field", - "json_dict_field_w_header", - "json_dict_field_w_header_and_zstd", ], ) @pytest.mark.benchmark(group="live_action_read") diff --git a/st2common/bin/migrations/v3.9/BUILD b/st2common/bin/migrations/v3.9/BUILD new file mode 100644 index 0000000000..255bf31004 --- /dev/null +++ b/st2common/bin/migrations/v3.9/BUILD @@ -0,0 +1,3 @@ +python_sources( + sources=["*.py", "st2*"], +) diff --git a/st2common/bin/migrations/v3.9/__init__.py b/st2common/bin/migrations/v3.9/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/st2common/bin/migrations/v3.9/st2-migrate-liveaction-executiondb b/st2common/bin/migrations/v3.9/st2-migrate-liveaction-executiondb new file mode 100755 index 0000000000..76e69bc845 --- /dev/null +++ b/st2common/bin/migrations/v3.9/st2-migrate-liveaction-executiondb @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +# Copyright 2021 The StackStorm Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Migration which replaces the embedded liveaction document on existing execution objects in the +database with a string liveaction_id reference. + +Migration step is idempotent and can be retried on failures / partial runs. + +Right now the script utilizes no concurrency and performs migration one object at a time. That's done +for simplicity reasons and also to avoid massive CPU usage spikes when running this script with +large concurrency on large objects. + +Keep in mind that only "completed" objects are processed - this means Executions in "final" states +(succeeded, failed, timeout, etc.). + +We determine if an object should be migrated using mongodb $type query (for execution objects we +could also determine that based on the presence of result_size field). +""" + +import sys +import datetime +import time +import traceback + +from oslo_config import cfg + +from st2common import config +from st2common.service_setup import db_setup +from st2common.service_setup import db_teardown +from st2common.util import isotime +from st2common.models.db.execution import ActionExecutionDB +from st2common.constants.action import ( + LIVEACTION_COMPLETED_STATES, + LIVEACTION_STATUS_PAUSED, + LIVEACTION_STATUS_PENDING, +) + +# NOTE: To avoid unnecessary mongoengine object churn when retrieving only object ids (aka to avoid +# instantiating model class with a single field), we use raw pymongo value which is a dict with a +# single value + + +def migrate_executions(start_dt: datetime.datetime, end_dt: datetime.datetime) -> None: + """ + Perform migrations for execution related objects (ActionExecutionDB, LiveActionDB). + """ + print("Migrating execution objects") + + # NOTE: We first only retrieve the IDs because there could be a lot of objects in the database + # and this could result in massive ram use. 
Technically, mongoengine loads querysets lazily, + # but this is not always the case so it's better to first retrieve all the IDs and then retrieve + # objects one by one. + # Keep in mind we need to use ModelClass.objects and not PersistanceClass.query() so .only() + # works correctly - with PersistanceClass.query().only() all the fields will still be retrieved. + # 1. Migrate ActionExecutionDB objects + res_count = ActionExecutionDB.objects( + __raw__={ + "status": { + "$in": LIVEACTION_COMPLETED_STATES + + [LIVEACTION_STATUS_PAUSED, LIVEACTION_STATUS_PENDING], + }, + }, + start_timestamp__gte=start_dt, + start_timestamp__lte=end_dt, + ).as_pymongo() + for item in res_count: + try: + ActionExecutionDB.objects(__raw__={"_id": item["_id"]}).update( + __raw__={"$set": {"liveaction_id": item["liveaction"]["id"]}} + ) + except KeyError: + pass + + ActionExecutionDB.objects( + __raw__={ + "status": { + "$in": LIVEACTION_COMPLETED_STATES + + [LIVEACTION_STATUS_PAUSED, LIVEACTION_STATUS_PENDING], + }, + }, + start_timestamp__gte=start_dt, + start_timestamp__lte=end_dt, + ).update(__raw__={"$unset": {"liveaction": 1}}) + + objects_count = res_count.count() + + print("migrated %s ActionExecutionDB objects" % (objects_count)) + print("") + + +def _register_cli_opts(): + cfg.CONF.register_cli_opt( + cfg.BoolOpt( + "yes", + short="y", + required=False, + default=False, + ) + ) + + # We default to past 30 days. Keep in mind that using longer period may take a long time in + # case there are many objects in the database. + now_dt = datetime.datetime.utcnow() + start_dt = now_dt - datetime.timedelta(days=30) + + cfg.CONF.register_cli_opt( + cfg.StrOpt( + "start-dt", + required=False, + help=( + "Start cut off ISO UTC iso date time string for objects which will be migrated. " + "Defaults to now - 30 days." 
+ "Example value: 2020-03-13T19:01:27Z" + ), + default=start_dt.strftime("%Y-%m-%dT%H:%M:%SZ"), + ) + ) + cfg.CONF.register_cli_opt( + cfg.StrOpt( + "end-dt", + required=False, + help=( + "End cut off UTC ISO date time string for objects which will be migrated." + "Defaults to now." + "Example value: 2020-03-13T19:01:27Z" + ), + default=now_dt.strftime("%Y-%m-%dT%H:%M:%SZ"), + ) + ) + + +def migrate_objects( + start_dt: datetime.datetime, end_dt: datetime.datetime, display_prompt: bool = True +) -> None: + start_dt_str = start_dt.strftime("%Y-%m-%d %H:%M:%S") + end_dt_str = end_dt.strftime("%Y-%m-%d %H:%M:%S") + + print("StackStorm v3.9 database field data migration script\n") + + if display_prompt: + input( + "Will migrate objects with creation date between %s UTC and %s UTC.\n\n" + "You are strongly recommended to create database backup before proceeding.\n\n" + "Depending on the number of the objects in the database, " + "migration may take multiple hours or more. You are recommended to start the " + "script in a screen session, tmux or similar. \n\n" + "To proceed with the migration, press enter and to cancel it, press CTRL+C.\n" + % (start_dt_str, end_dt_str) + ) + print("") + + print( + "Migrating affected database objects between %s and %s" + % (start_dt_str, end_dt_str) + ) + print("") + + start_ts = int(time.time()) + migrate_executions(start_dt=start_dt, end_dt=end_dt) + end_ts = int(time.time()) + + duration = end_ts - start_ts + + print( + "SUCCESS: All database objects migrated successfully (duration: %s seconds)." 
+ % (duration) + ) + + +def main(): + _register_cli_opts() + + config.parse_args() + db_setup() + + start_dt = isotime.parse(cfg.CONF.start_dt) + + if cfg.CONF.end_dt == "now": + end_dt = datetime.datetime.utcnow() + end_dt = end_dt.replace(tzinfo=datetime.timezone.utc) + else: + end_dt = isotime.parse(cfg.CONF.end_dt) + + try: + migrate_objects( + start_dt=start_dt, end_dt=end_dt, display_prompt=not cfg.CONF.yes + ) + exit_code = 0 + except Exception as e: + print("ABORTED: Objects migration aborted on first failure: %s" % (str(e))) + traceback.print_exc() + exit_code = 1 + + db_teardown() + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/st2common/st2common/config.py b/st2common/st2common/config.py index c88955e4bb..c204fa0d93 100644 --- a/st2common/st2common/config.py +++ b/st2common/st2common/config.py @@ -20,6 +20,7 @@ from oslo_config import cfg +from st2common.constants.compression import ZSTANDARD_COMPRESS, VALID_COMPRESS from st2common.constants.system import VERSION_STRING from st2common.constants.system import DEFAULT_CONFIG_FILE_PATH from st2common.constants.runners import PYTHON_RUNNER_DEFAULT_LOG_LEVEL @@ -243,6 +244,14 @@ def register_opts(ignore_errors=False): "By default, it use SCRAM-SHA-1 with MongoDB 3.0 and later, " "MONGODB-CR (MongoDB Challenge Response protocol) for older servers.", ), + cfg.StrOpt( + "parameter_result_compression", + default=ZSTANDARD_COMPRESS, + required=True, + choices=VALID_COMPRESS, + help="compression for parameter and result storage in liveaction and " + "execution models", + ), cfg.StrOpt( "compressors", default="", diff --git a/st2common/st2common/constants/compression.py b/st2common/st2common/constants/compression.py new file mode 100644 index 0000000000..edb3581cf9 --- /dev/null +++ b/st2common/st2common/constants/compression.py @@ -0,0 +1,89 @@ +# Copyright 2020 The StackStorm Authors. +# Copyright 2019 Extreme Networks, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mongoengine is licensed under MIT. +""" + + +import enum +from oslo_config import cfg +import zstandard + +ZSTANDARD_COMPRESS = "zstandard" +NO_COMPRESSION = "none" + +VALID_COMPRESS = [ZSTANDARD_COMPRESS, NO_COMPRESSION] + + +class JSONDictFieldCompressionAlgorithmEnum(enum.Enum): + """ + Enum which represents compression algorithm (if any) used for a specific JSONDictField value. + """ + + ZSTANDARD = b"z" + + +VALID_JSON_DICT_COMPRESSION_ALGORITHMS = [ + JSONDictFieldCompressionAlgorithmEnum.ZSTANDARD.value, +] + + +def zstandard_compress(data): + data = ( + JSONDictFieldCompressionAlgorithmEnum.ZSTANDARD.value + + zstandard.ZstdCompressor().compress(data) + ) + return data + + +def zstandard_uncompress(data): + data = zstandard.ZstdDecompressor().decompress(data) + return data + + +MAP_COMPRESS = { + ZSTANDARD_COMPRESS: zstandard_compress, +} + + +MAP_UNCOMPRESS = { + JSONDictFieldCompressionAlgorithmEnum.ZSTANDARD.value: zstandard_uncompress, +} + + +def uncompress(value: bytes): + data = value + try: + uncompression_header = value[0:1] + uncompression_method = MAP_UNCOMPRESS.get(uncompression_header, False) + if uncompression_method: # skip if no compress + data = uncompression_method(value[1:]) + # will need to add additional exceptions if additional compression methods + # are added in the future; please do not catch the general exception here. 
+ except zstandard.ZstdError: + # skip if already a byte string and not zstandard compressed + pass + return data + + +def compress(value: bytes): + data = value + parameter_result_compression = cfg.CONF.database.parameter_result_compression + compression_method = MAP_COMPRESS.get(parameter_result_compression, False) + # none is not mapped at all so has no compression method + if compression_method: + data = compression_method(value) + return data diff --git a/st2common/st2common/fields.py b/st2common/st2common/fields.py index 0e94f11f85..f6a821aa73 100644 --- a/st2common/st2common/fields.py +++ b/st2common/st2common/fields.py @@ -27,7 +27,6 @@ import datetime import calendar -import enum import weakref import orjson @@ -38,6 +37,10 @@ from mongoengine.base.datastructures import mark_key_as_changed_wrapper from mongoengine.common import _import_class +from st2common.constants.compression import ( + compress as compress_function, + uncompress as uncompress_function, +) from st2common.util import date as date_utils from st2common.util import mongoescape @@ -49,34 +52,6 @@ JSON_DICT_FIELD_DELIMITER = b":" -class JSONDictFieldCompressionAlgorithmEnum(enum.Enum): - """ - Enum which represents compression algorithm (if any) used for a specific JSONDictField value. - """ - - NONE = b"n" - ZSTANDARD = b"z" - - -class JSONDictFieldSerializationFormatEnum(enum.Enum): - """ - Enum which represents serialization format used for a specific JSONDictField value. 
- """ - - ORJSON = b"o" - - -VALID_JSON_DICT_COMPRESSION_ALGORITHMS = [ - JSONDictFieldCompressionAlgorithmEnum.NONE.value, - JSONDictFieldCompressionAlgorithmEnum.ZSTANDARD.value, -] - - -VALID_JSON_DICT_SERIALIZATION_FORMATS = [ - JSONDictFieldSerializationFormatEnum.ORJSON.value, -] - - class ComplexDateTimeField(LongField): """ Date time field which handles microseconds exactly and internally stores @@ -331,13 +306,14 @@ def _mark_as_changed(self, key=None): class JSONDictField(BinaryField): """ - Custom field types which stores dictionary as JSON serialized strings. + Custom field types which stores dictionary as zstandard compressed JSON serialized strings. - This is done because storing large objects as JSON serialized strings is much more fficient + This is done because storing large objects as compressed JSON serialized + strings is much more efficient on the serialize and unserialize paths compared to used EscapedDictField which needs to escape all the special values ($, .). - Only downside is that to MongoDB those values are plain raw strings which means you can't query + Only downside is that to MongoDB those values are compressed plain raw strings which means you can't query on actual dictionary field values. That's not an issue for us, because in places where we use it, those values are already treated as plain binary blobs to the database layer and we never directly query on those field values. @@ -358,25 +334,11 @@ class JSONDictField(BinaryField): IMPLEMENTATION DETAILS: - If header is used, values are stored in the following format: - ::. - - For example: - n:o:... - No compression, (or)json serialization - z:o:... - Zstandard compression, (or)json serialization - If header is not used, value is stored as a serialized JSON string of the input dictionary. """ def __init__(self, *args, **kwargs): - # True if we should use field header which is more future proof approach and also allows - # us to support optional per-field compression, etc. 
- # This option is only exposed so we can benchmark different approaches and how much overhead - # using a header adds. - self.use_header = kwargs.pop("use_header", False) - self.compression_algorithm = kwargs.pop("compression_algorithm", "none") - super(JSONDictField, self).__init__(*args, **kwargs) def to_mongo(self, value): @@ -403,11 +365,6 @@ def validate(self, value): def parse_field_value(self, value: Optional[Union[bytes, dict]]) -> dict: """ Parse provided binary field value and return parsed value (dictionary). - - For example: - - - (n, o, ...) - no compression, data is serialized using orjson - - (z, o, ...) - zstandard compression, data is serialized using orjson """ if not value: return self.default @@ -416,45 +373,10 @@ def parse_field_value(self, value: Optional[Union[bytes, dict]]) -> dict: # Already deserializaed return value - if not self.use_header: - return orjson.loads(value) - - split = value.split(JSON_DICT_FIELD_DELIMITER, 2) - - if len(split) != 3: - raise ValueError( - "Expected 3 values when splitting field value, got %s" % (len(split)) - ) - - compression_algorithm = split[0] - serialization_format = split[1] - data = split[2] - - if compression_algorithm not in VALID_JSON_DICT_COMPRESSION_ALGORITHMS: - raise ValueError( - "Invalid or unsupported value for compression algorithm header " - "value: %s" % (compression_algorithm) - ) - - if serialization_format not in VALID_JSON_DICT_SERIALIZATION_FORMATS: - raise ValueError( - "Invalid or unsupported value for serialization format header " - "value: %s" % (serialization_format) - ) - - if ( - compression_algorithm - == JSONDictFieldCompressionAlgorithmEnum.ZSTANDARD.value - ): - # NOTE: At this point zstandard is only test dependency - import zstandard - - data = zstandard.ZstdDecompressor().decompress(data) - - data = orjson.loads(data) + data = orjson.loads(uncompress_function(value)) return data - def _serialize_field_value(self, value: dict) -> bytes: + def _serialize_field_value(self, 
value: dict, compress=True) -> bytes: """ Serialize and encode the provided field value. """ @@ -474,21 +396,10 @@ def default(obj): return list(obj) raise TypeError - if not self.use_header: - return orjson.dumps(value, default=default) - data = orjson.dumps(value, default=default) - - if self.compression_algorithm == "zstandard": - # NOTE: At this point zstandard is only test dependency - import zstandard - - compression_header = JSONDictFieldCompressionAlgorithmEnum.ZSTANDARD - data = zstandard.ZstdCompressor().compress(data) - else: - compression_header = JSONDictFieldCompressionAlgorithmEnum.NONE - - return compression_header.value + b":" + b"o:" + data + if compress: + data = compress_function(data) + return data def __get__(self, instance, owner): """ @@ -522,11 +433,6 @@ class JSONDictEscapedFieldCompatibilityField(JSONDictField): def to_mongo(self, value): if isinstance(value, bytes): # Already serialized - if value[0] == b"{" and self.use_header: - # Serialized, but doesn't contain header prefix, add it (assume migration from - # format without a header) - return "n:o:" + value - return value if not isinstance(value, dict): diff --git a/st2common/st2common/models/api/action.py b/st2common/st2common/models/api/action.py index dc18ba02cd..f82220950b 100644 --- a/st2common/st2common/models/api/action.py +++ b/st2common/st2common/models/api/action.py @@ -34,6 +34,7 @@ from st2common.models.db.runner import RunnerTypeDB from st2common.constants.action import LIVEACTION_STATUSES from st2common.models.system.common import ResourceReference +from st2common.fields import JSONDictEscapedFieldCompatibilityField __all__ = [ @@ -440,9 +441,23 @@ class LiveActionAPI(BaseAPI): }, "additionalProperties": False, } - skip_unescape_field_names = [ - "result", - ] + skip_unescape_field_names = ["result", "parameters"] + + @classmethod + def convert_raw(cls, doc, raw_values): + """ + override this class to + convert any raw byte values into dict + + :param doc: dict + :param 
raw_values: dict[field]:bytestring + """ + + for field_name, field_value in raw_values.items(): + doc[ + field_name + ] = JSONDictEscapedFieldCompatibilityField().parse_field_value(field_value) + return doc @classmethod def from_model(cls, model, mask_secrets=False): @@ -451,7 +466,6 @@ def from_model(cls, model, mask_secrets=False): doc["start_timestamp"] = isotime.format(model.start_timestamp, offset=False) if model.end_timestamp: doc["end_timestamp"] = isotime.format(model.end_timestamp, offset=False) - if getattr(model, "notify", None): doc["notify"] = NotificationsHelper.from_model(model.notify) diff --git a/st2common/st2common/models/api/base.py b/st2common/st2common/models/api/base.py index 6cdb16feef..445fe39023 100644 --- a/st2common/st2common/models/api/base.py +++ b/st2common/st2common/models/api/base.py @@ -22,6 +22,7 @@ from st2common.util import mongoescape as util_mongodb from st2common import log as logging +from st2common.models.db.stormbase import EscapedDynamicField, EscapedDictField __all__ = ["BaseAPI", "APIUIDMixin"] @@ -86,6 +87,12 @@ def validate(self): @classmethod def _from_model(cls, model, mask_secrets=False): + unescape_fields = [ + k + for k, v in model._fields.items() + if type(v) in [EscapedDynamicField, EscapedDictField] + ] + unescape_fields = set(unescape_fields) - set(cls.skip_unescape_field_names) doc = model.to_mongo() if "_id" in doc: @@ -94,32 +101,35 @@ def _from_model(cls, model, mask_secrets=False): # Special case for models which utilize JSONDictField - there is no need to escape those # fields since it contains a JSON string and not a dictionary which doesn't need to be # mongo escaped. Skipping this step here substantially speeds things up for that field. - - # Right now we do this here manually for all those fields types but eventually we should - # refactor the code to just call unescape chars on escaped fields - more generic and - # faster. 
raw_values = {} - for field_name in cls.skip_unescape_field_names: if isinstance(doc.get(field_name, None), bytes): raw_values[field_name] = doc.pop(field_name) - - # TODO (Tomaz): In general we really shouldn't need to call unescape chars on the whole doc, - # but just on the EscapedDict and EscapedDynamicField fields - doing it on the whole doc - # level is slow and not necessary! - doc = util_mongodb.unescape_chars(doc) - - # Now add the JSON string field value which shouldn't be escaped back. - # We don't JSON parse the field value here because that happens inside the model specific - # "from_model()" method where we also parse and convert all the other field values. - for field_name, field_value in raw_values.items(): - doc[field_name] = field_value + for key in unescape_fields: + if key in doc.keys(): + doc[key] = util_mongodb.unescape_chars(doc[key]) + # convert raw fields and add back ; no need to unescape + doc = cls.convert_raw(doc, raw_values) if mask_secrets and cfg.CONF.log.mask_secrets: doc = model.mask_secrets(value=doc) return doc + @classmethod + def convert_raw(cls, doc, raw_values): + """ + override this class to + convert any raw byte values into dict + you can also use this to fix any other fields that need 'fixing' + + :param doc: dict + :param raw_values: dict[field]:bytestring + """ + for field_name, field_value in raw_values.items(): + doc[field_name] = field_value + return doc + @classmethod def from_model(cls, model, mask_secrets=False): """ diff --git a/st2common/st2common/models/api/execution.py b/st2common/st2common/models/api/execution.py index 76aa9fbdf8..019b367eed 100644 --- a/st2common/st2common/models/api/execution.py +++ b/st2common/st2common/models/api/execution.py @@ -29,6 +29,7 @@ from st2common.models.api.action import RunnerTypeAPI, ActionAPI, LiveActionAPI from st2common import log as logging from st2common.util.deep_copy import fast_deepcopy_dict +from st2common.fields import JSONDictEscapedFieldCompatibilityField __all__ 
= ["ActionExecutionAPI", "ActionExecutionOutputAPI"] @@ -145,9 +146,7 @@ class ActionExecutionAPI(BaseAPI): }, "additionalProperties": False, } - skip_unescape_field_names = [ - "result", - ] + skip_unescape_field_names = ["result", "parameters"] @classmethod def from_model(cls, model, mask_secrets=False): @@ -171,6 +170,24 @@ def from_model(cls, model, mask_secrets=False): attrs = {attr: value for attr, value in six.iteritems(doc) if value} return cls(**attrs) + @classmethod + def convert_raw(cls, doc, raw_values): + """ + override this class to + convert any raw byte values into dict + Now add the JSON string field value which shouldn't be escaped back. + We don't JSON parse the field value here because that happens inside the model specific + "from_model()" method where we also parse and convert all the other field values. + :param doc: dict + :param raw_values: dict[field]:bytestring + """ + + for field_name, field_value in raw_values.items(): + doc[ + field_name + ] = JSONDictEscapedFieldCompatibilityField().parse_field_value(field_value) + return doc + @classmethod def to_model(cls, instance): values = {} diff --git a/st2common/st2common/models/db/execution.py b/st2common/st2common/models/db/execution.py index 0de35a5c31..c0ffb4ab8f 100644 --- a/st2common/st2common/models/db/execution.py +++ b/st2common/st2common/models/db/execution.py @@ -61,7 +61,8 @@ class ActionExecutionDB(stormbase.StormFoundationDB): end_timestamp = ComplexDateTimeField( help_text="The timestamp when the liveaction has finished." ) - parameters = stormbase.EscapedDynamicField( + action = stormbase.EscapedDictField(required=True) + parameters = JSONDictEscapedFieldCompatibilityField( default={}, help_text="The key-value pairs passed as to the action runner & action.", ) @@ -72,6 +73,15 @@ class ActionExecutionDB(stormbase.StormFoundationDB): context = me.DictField( default={}, help_text="Contextual information on the action execution." 
) + delay = me.IntField(min_value=0) + + # diff from liveaction + runner = stormbase.EscapedDictField(required=True) + trigger = stormbase.EscapedDictField() + trigger_type = stormbase.EscapedDictField() + trigger_instance = stormbase.EscapedDictField() + rule = stormbase.EscapedDictField() + result_size = me.IntField(default=0, help_text="Serialized result size in bytes") parent = me.StringField() children = me.ListField(field=me.StringField()) log = me.ListField(field=me.DictField()) @@ -115,7 +125,6 @@ def mask_secrets(self, value): :return: result: action execution object with masked secret paramters in input and output schema. :rtype: result: ``dict`` """ - result = copy.deepcopy(value) liveaction = result["liveaction"] diff --git a/st2common/st2common/models/db/liveaction.py b/st2common/st2common/models/db/liveaction.py index aef52462a6..99cad1982d 100644 --- a/st2common/st2common/models/db/liveaction.py +++ b/st2common/st2common/models/db/liveaction.py @@ -54,11 +54,7 @@ class LiveActionDB(stormbase.StormFoundationDB): action = me.StringField( required=True, help_text="Reference to the action that has to be executed." ) - action_is_workflow = me.BooleanField( - default=False, - help_text="A flag indicating whether the referenced action is a workflow.", - ) - parameters = stormbase.EscapedDynamicField( + parameters = JSONDictEscapedFieldCompatibilityField( default={}, help_text="The key-value pairs passed as to the action runner & execution.", ) @@ -68,19 +64,24 @@ class LiveActionDB(stormbase.StormFoundationDB): context = me.DictField( default={}, help_text="Contextual information on the action execution." 
) + delay = me.IntField( + min_value=0, + help_text="How long (in milliseconds) to delay the execution before scheduling.", + ) + # diff from action execution + action_is_workflow = me.BooleanField( + default=False, + help_text="A flag indicating whether the referenced action is a workflow.", + ) callback = me.DictField( default={}, help_text="Callback information for the on completion of action execution.", ) + notify = me.EmbeddedDocumentField(NotificationSchema) runner_info = me.DictField( default={}, help_text="Information about the runner which executed this live action (hostname, pid).", ) - notify = me.EmbeddedDocumentField(NotificationSchema) - delay = me.IntField( - min_value=0, - help_text="How long (in milliseconds) to delay the execution before scheduling.", - ) meta = { "indexes": [ @@ -114,6 +115,23 @@ def mask_secrets(self, value): result["parameters"] = mask_secret_parameters( parameters=execution_parameters, secret_parameters=secret_parameters ) + if result.get("action", "") == "st2.inquiry.respond": + # In this case, this execution is just a plain python action, not + # an inquiry, so we don't natively have a handle on the response + # schema. + # + # To prevent leakage, we can just mask all response fields. + # + # Note: The 'string' type in secret_parameters doesn't matter, + # it's just a placeholder to tell mask_secret_parameters() + # that this parameter is indeed a secret parameter and to + # mask it. 
+ result["parameters"]["response"] = mask_secret_parameters( + parameters=result["parameters"]["response"], + secret_parameters={ + p: "string" for p in result["parameters"]["response"] + }, + ) return result def get_masked_parameters(self): diff --git a/st2common/st2common/models/db/stormbase.py b/st2common/st2common/models/db/stormbase.py index e75e381941..67c93f42a1 100644 --- a/st2common/st2common/models/db/stormbase.py +++ b/st2common/st2common/models/db/stormbase.py @@ -108,10 +108,8 @@ def to_serializable_dict(self, mask_secrets=False): v = json_decode(v.to_json()) serializable_dict[k] = v - if mask_secrets and cfg.CONF.log.mask_secrets: serializable_dict = self.mask_secrets(value=serializable_dict) - return serializable_dict diff --git a/st2common/st2common/services/executions.py b/st2common/st2common/services/executions.py index 80706e8f79..4b984cd587 100644 --- a/st2common/st2common/services/executions.py +++ b/st2common/st2common/services/executions.py @@ -225,7 +225,7 @@ def update_execution(liveaction_db, publish=True, set_result_size=False): with Timer(key="action.executions.calculate_result_size"): result_size = len( ActionExecutionDB.result._serialize_field_value( - liveaction_db.result + value=liveaction_db.result, compress=False ) ) kw["set__result_size"] = result_size diff --git a/st2common/st2common/services/policies.py b/st2common/st2common/services/policies.py index 46e24ce290..c4aa1ac87e 100644 --- a/st2common/st2common/services/policies.py +++ b/st2common/st2common/services/policies.py @@ -15,6 +15,9 @@ from __future__ import absolute_import +import sys +import traceback + from st2common.constants import action as ac_const from st2common import log as logging from st2common.persistence import policy as pc_db_access @@ -58,9 +61,17 @@ def apply_pre_run_policies(lv_ac_db): LOG.info(message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id))) lv_ac_db = driver.apply_before(lv_ac_db) except: - message = 'An exception occurred while 
applying policy "%s" (%s) for liveaction "%s".' + _, ex, tb = sys.exc_info() + traceback_var = "".join(traceback.format_tb(tb, 20)) + message = 'An exception occurred while applying policy "%s" (%s) for liveaction "%s". traceback "%s"' LOG.exception( - message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id)) + message + % ( + policy_db.ref, + policy_db.policy_type, + str(lv_ac_db.id), + traceback_var, + ) ) if lv_ac_db.status == ac_const.LIVEACTION_STATUS_DELAYED: diff --git a/st2common/st2common/services/workflows.py b/st2common/st2common/services/workflows.py index b84671f8b1..2aa50cee5f 100644 --- a/st2common/st2common/services/workflows.py +++ b/st2common/st2common/services/workflows.py @@ -248,6 +248,7 @@ def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None): ) # Instantiate the workflow conductor. + LOG.info("action_params: " + str(action_params)) conductor_params = {"inputs": action_params, "context": st2_ctx} conductor = conducting.WorkflowConductor(wf_spec, **conductor_params) @@ -666,7 +667,7 @@ def request_task_execution(wf_ex_db, st2_ctx, task_ex_req): except Exception as e: msg = 'Failed action execution(s) for task "%s", route "%s".' 
msg = msg % (task_id, str(task_route)) - LOG.exception(msg) + LOG.exception(msg, exc_info=True) msg = "%s %s: %s" % (msg, type(e).__name__, six.text_type(e)) update_progress(wf_ex_db, msg, severity="error", log=False) msg = "%s: %s" % (type(e).__name__, six.text_type(e)) @@ -1189,7 +1190,7 @@ def request_next_tasks(wf_ex_db, task_ex_id=None): update_progress( wf_ex_db, "%s %s" % (msg, str(e)), severity="error", log=False ) - LOG.exception(msg) + LOG.exception(msg, exc_info=True) fail_workflow_execution(str(wf_ex_db.id), e, task=task) return diff --git a/st2common/tests/unit/test_db_fields.py b/st2common/tests/unit/test_db_fields.py index 9abd587fb5..3258aa158f 100644 --- a/st2common/tests/unit/test_db_fields.py +++ b/st2common/tests/unit/test_db_fields.py @@ -20,9 +20,9 @@ import calendar import mock +from oslo_config import cfg import unittest2 import orjson -import zstandard # pytest: make sure monkey_patching happens before importing mongoengine from st2common.util.monkey_patch import monkey_patch @@ -37,8 +37,6 @@ from st2common.models.db import MongoDBAccess from st2common.fields import JSONDictField from st2common.fields import JSONDictEscapedFieldCompatibilityField -from st2common.fields import JSONDictFieldCompressionAlgorithmEnum -from st2common.fields import JSONDictFieldSerializationFormatEnum from st2tests import DbTestCase @@ -79,6 +77,14 @@ class ModelWithJSONDictFieldDB(stormbase.StormFoundationDB): class JSONDictFieldTestCase(unittest2.TestCase): + def setUp(self): + # NOTE: It's important we re-establish a connection on each setUp + cfg.CONF.reset() + + def tearDown(self): + # NOTE: It's important we disconnect here otherwise tests will fail + cfg.CONF.reset() + def test_set_to_mongo(self): field = JSONDictField(use_header=False) result = field.to_mongo({"test": {1, 2}}) @@ -89,12 +95,27 @@ def test_header_set_to_mongo(self): result = field.to_mongo({"test": {1, 2}}) self.assertTrue(isinstance(result, bytes)) - def test_to_mongo(self): + def 
test_to_mongo_to_python_none(self): + cfg.CONF.set_override( + name="parameter_result_compression", group="database", override="none" + ) + field = JSONDictField(use_header=False) + result = field.to_mongo(MOCK_DATA_DICT) + + self.assertTrue(isinstance(result, bytes)) + result = field.to_python(result) + self.assertEqual(result, MOCK_DATA_DICT) + + def test_to_mongo_zstandard(self): + cfg.CONF.set_override( + name="parameter_result_compression", group="database", override="zstandard" + ) field = JSONDictField(use_header=False) result = field.to_mongo(MOCK_DATA_DICT) self.assertTrue(isinstance(result, bytes)) - self.assertEqual(result, orjson.dumps(MOCK_DATA_DICT)) + result = field.to_python(result) + self.assertEqual(result, MOCK_DATA_DICT) def test_to_python(self): field = JSONDictField(use_header=False) @@ -147,75 +168,13 @@ def test_parse_field_value(self): self.assertEqual(result, {"c": "d"}) -class JSONDictFieldTestCaseWithHeader(unittest2.TestCase): - def test_to_mongo_no_compression(self): - field = JSONDictField(use_header=True) - - result = field.to_mongo(MOCK_DATA_DICT) - self.assertTrue(isinstance(result, bytes)) - - split = result.split(b":", 2) - self.assertEqual(split[0], JSONDictFieldCompressionAlgorithmEnum.NONE.value) - self.assertEqual(split[1], JSONDictFieldSerializationFormatEnum.ORJSON.value) - self.assertEqual(orjson.loads(split[2]), MOCK_DATA_DICT) - - parsed_value = field.parse_field_value(result) - self.assertEqual(parsed_value, MOCK_DATA_DICT) - - def test_to_mongo_zstandard_compression(self): - field = JSONDictField(use_header=True, compression_algorithm="zstandard") - - result = field.to_mongo(MOCK_DATA_DICT) - self.assertTrue(isinstance(result, bytes)) - - split = result.split(b":", 2) - self.assertEqual( - split[0], JSONDictFieldCompressionAlgorithmEnum.ZSTANDARD.value - ) - self.assertEqual(split[1], JSONDictFieldSerializationFormatEnum.ORJSON.value) - self.assertEqual( - 
orjson.loads(zstandard.ZstdDecompressor().decompress(split[2])), - MOCK_DATA_DICT, - ) - - parsed_value = field.parse_field_value(result) - self.assertEqual(parsed_value, MOCK_DATA_DICT) - - def test_to_python_no_compression(self): - field = JSONDictField(use_header=True) - - serialized_data = field.to_mongo(MOCK_DATA_DICT) - - self.assertTrue(isinstance(serialized_data, bytes)) - split = serialized_data.split(b":", 2) - self.assertEqual(split[0], JSONDictFieldCompressionAlgorithmEnum.NONE.value) - self.assertEqual(split[1], JSONDictFieldSerializationFormatEnum.ORJSON.value) - - desserialized_data = field.to_python(serialized_data) - self.assertEqual(desserialized_data, MOCK_DATA_DICT) - - def test_to_python_zstandard_compression(self): - field = JSONDictField(use_header=True, compression_algorithm="zstandard") - - serialized_data = field.to_mongo(MOCK_DATA_DICT) - self.assertTrue(isinstance(serialized_data, bytes)) - - split = serialized_data.split(b":", 2) - self.assertEqual( - split[0], JSONDictFieldCompressionAlgorithmEnum.ZSTANDARD.value - ) - self.assertEqual(split[1], JSONDictFieldSerializationFormatEnum.ORJSON.value) - - desserialized_data = field.to_python(serialized_data) - self.assertEqual(desserialized_data, MOCK_DATA_DICT) - - class JSONDictEscapedFieldCompatibilityFieldTestCase(DbTestCase): def test_to_mongo(self): field = JSONDictEscapedFieldCompatibilityField(use_header=False) result_to_mongo_1 = field.to_mongo(MOCK_DATA_DICT) - self.assertEqual(result_to_mongo_1, orjson.dumps(MOCK_DATA_DICT)) + self.assertTrue(isinstance(result_to_mongo_1, bytes)) + self.assertEqual(result_to_mongo_1[0:1], b"z") # Already serialized result_to_mongo_2 = field.to_mongo(MOCK_DATA_DICT) @@ -275,7 +234,12 @@ def test_existing_db_value_is_using_escaped_dict_field_compatibility(self): self.assertEqual(len(pymongo_result), 1) self.assertEqual(pymongo_result[0]["_id"], inserted_model_db.id) self.assertTrue(isinstance(pymongo_result[0]["result"], bytes)) - 
self.assertEqual(orjson.loads(pymongo_result[0]["result"]), expected_data) + + result = pymongo_result[0]["result"] + + field = JSONDictField(use_header=False) + result = field.to_python(result) + self.assertEqual(result, expected_data) self.assertEqual(pymongo_result[0]["counter"], 1) def test_field_state_changes_are_correctly_detected_add_or_update_method(self): diff --git a/st2tests/integration/orquesta/test_wiring_error_handling.py b/st2tests/integration/orquesta/test_wiring_error_handling.py index 130a68c7c5..1c906ce20d 100644 --- a/st2tests/integration/orquesta/test_wiring_error_handling.py +++ b/st2tests/integration/orquesta/test_wiring_error_handling.py @@ -66,7 +66,27 @@ def test_inspection_error(self): ex = self._execute_workflow("examples.orquesta-fail-inspection") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) + self.assertEqual(len(ex.result["errors"]), len(expected_errors)) + for i in range(len(ex.result["errors"])): + self.assertEqual(ex.result["errors"][i]["type"], expected_errors[i]["type"]) + self.assertEqual( + ex.result["errors"][i]["message"], expected_errors[i]["message"] + ) + self.assertEqual( + ex.result["errors"][i]["schema_path"], expected_errors[i]["schema_path"] + ) + self.assertEqual( + ex.result["errors"][i]["spec_path"], expected_errors[i]["spec_path"] + ) + if "language" in expected_errors[i]: + self.assertEqual( + ex.result["errors"][i]["language"], expected_errors[i]["language"] + ) + if "expression" in expected_errors[i]: + self.assertEqual( + ex.result["errors"][i]["expression"], + expected_errors[i]["expression"], + ) def test_input_error(self): expected_errors = [ @@ -227,7 +247,18 @@ def test_task_content_errors(self): ex = self._execute_workflow("examples.orquesta-fail-inspection-task-contents") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - 
self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) + self.assertEqual(len(ex.result["errors"]), len(expected_errors)) + for i in range(len(ex.result["errors"])): + self.assertEqual(ex.result["errors"][i]["type"], expected_errors[i]["type"]) + self.assertEqual( + ex.result["errors"][i]["message"], expected_errors[i]["message"] + ) + self.assertEqual( + ex.result["errors"][i]["schema_path"], expected_errors[i]["schema_path"] + ) + self.assertEqual( + ex.result["errors"][i]["spec_path"], expected_errors[i]["spec_path"] + ) def test_remediate_then_fail(self): expected_errors = [ diff --git a/st2tests/st2tests/fixtures/descendants/liveactions/liveaction_fake.yaml b/st2tests/st2tests/fixtures/descendants/liveactions/liveaction_fake.yaml new file mode 100644 index 0000000000..b933dedf63 --- /dev/null +++ b/st2tests/st2tests/fixtures/descendants/liveactions/liveaction_fake.yaml @@ -0,0 +1,5 @@ +--- +action: local +name: "fake" +id: 54c6b6d60640fd4f5354e74a +status: succeeded