diff --git a/q2_mystery_stew/generators/__init__.py b/q2_mystery_stew/generators/__init__.py index 15d8938..7526806 100644 --- a/q2_mystery_stew/generators/__init__.py +++ b/q2_mystery_stew/generators/__init__.py @@ -10,9 +10,10 @@ primitive_union_params) from .metadata import metadata_params from .artifacts import artifact_params -from .collections import list_paramgen, set_paramgen +from .collections import list_paramgen, collection_paramgen from .actions import (generate_single_type_methods, - generate_multiple_output_methods) + generate_multiple_output_methods, + generate_output_collection_methods) from .base import ParamTemplate, ActionTemplate, ParamSpec, Invocation BASIC_GENERATORS = { @@ -24,7 +25,8 @@ 'strings': string_params, 'primitive_unions': primitive_union_params, } -FILTERS = {*BASIC_GENERATORS.keys(), 'collections', 'typemaps', 'outputs'} +FILTERS = {*BASIC_GENERATORS.keys(), 'collections', 'typemaps', 'outputs', + 'output_collections'} from .typemaps import generate_typemap_methods # noqa: E402 @@ -38,23 +40,24 @@ def should_add(filter_): add_collections = should_add('collections') lists = [] - sets = [] + collections = [] for key, generator in BASIC_GENERATORS.items(): if should_add(key): selected_generators.append(generator()) if add_collections and key != 'metadata': lists.append(list_paramgen(generator())) - sets.append(set_paramgen(generator())) + collections.append(collection_paramgen(generator())) selected_generators.extend(lists) - selected_generators.extend(sets) + selected_generators.extend(collections) return selected_generators __all__ = ['int_params', 'float_params', 'string_params', 'bool_params', 'primitive_union_params', 'metadata_params', 'artifact_params', - 'list_paramgen', 'set_paramgen', 'generate_single_type_methods', - 'generate_multiple_output_methods', 'generate_typemap_methods', + 'list_paramgen', 'collection_paramgen', + 'generate_single_type_methods', 'generate_multiple_output_methods', + 'generate_output_collection_methods', 'generate_typemap_methods', 'BASIC_GENERATORS', 'FILTERS', 'get_param_generators', 'ParamTemplate', 'ParamSpec', 'ActionTemplate', 'Invocation'] diff --git a/q2_mystery_stew/generators/actions.py b/q2_mystery_stew/generators/actions.py index c9fc252..486e437 100644 --- a/q2_mystery_stew/generators/actions.py +++ b/q2_mystery_stew/generators/actions.py @@ -8,6 +8,7 @@ from collections import deque +from qiime2.core.type import Collection from qiime2.sdk.util import is_metadata_type, is_semantic_type from q2_mystery_stew.type import EchoOutput @@ -71,3 +72,28 @@ def generate_multiple_output_methods(): parameter_specs={}, registered_outputs=qiime_outputs, invocation_domain=[Invocation({}, qiime_outputs)]) + + +def generate_output_collection_methods(): + action_id = 'collection_only' + qiime_outputs = [('output', Collection[EchoOutput])] + yield ActionTemplate(action_id=action_id, + parameter_specs={}, + registered_outputs=qiime_outputs, + invocation_domain=[Invocation({}, qiime_outputs)]) + + action_id = 'collection_first' + qiime_outputs = [('output_collection', Collection[EchoOutput]), + ('output', EchoOutput)] + yield ActionTemplate(action_id=action_id, + parameter_specs={}, + registered_outputs=qiime_outputs, + invocation_domain=[Invocation({}, qiime_outputs)]) + + action_id = 'collection_second' + qiime_outputs = [('output', EchoOutput), + ('output_collection', Collection[EchoOutput])] + yield ActionTemplate(action_id=action_id, + parameter_specs={}, + registered_outputs=qiime_outputs, + invocation_domain=[Invocation({}, qiime_outputs)]) diff --git a/q2_mystery_stew/generators/collections.py b/q2_mystery_stew/generators/collections.py index d05dd11..d8dd3a7 100644 --- a/q2_mystery_stew/generators/collections.py +++ b/q2_mystery_stew/generators/collections.py @@ -8,8 +8,8 @@ import itertools - -from qiime2.plugin import List, Set +from qiime2.plugin import List, Collection +from qiime2.core.type.util import is_semantic_type from q2_mystery_stew.generators.base import ParamTemplate @@ -30,22 +30,33 @@ def underpowered_set(iterable): def list_paramgen(generator): def make_list(): for param in generator: + if is_semantic_type(param.qiime_type): + view_type = param.view_type + else: + view_type = list + yield ParamTemplate( param.base_name + "_list", List[param.qiime_type], - param.view_type, + view_type, tuple(list(x) for x in underpowered_set(param.domain))) make_list.__name__ = 'list_' + generator.__name__ return make_list() -def set_paramgen(generator): - def make_set(): +def collection_paramgen(generator): + def make_collection(): for param in generator: + if is_semantic_type(param.qiime_type): + view_type = param.view_type + else: + view_type = dict + yield ParamTemplate( - param.base_name + "_set", - Set[param.qiime_type], - param.view_type, - tuple(set(x) for x in underpowered_set(param.domain))) - make_set.__name__ = 'set_' + generator.__name__ - return make_set() + param.base_name + "_collection", + Collection[param.qiime_type], + view_type, + tuple({str(k): v for k, v in enumerate(x)} + for x in underpowered_set(param.domain))) + make_collection.__name__ = 'collection_' + generator.__name__ + return make_collection() diff --git a/q2_mystery_stew/generators/typemaps.py b/q2_mystery_stew/generators/typemaps.py index 45d3401..71b84fe 100644 --- a/q2_mystery_stew/generators/typemaps.py +++ b/q2_mystery_stew/generators/typemaps.py @@ -26,7 +26,8 @@ from q2_mystery_stew.generators.artifacts import ( single_int1_1, single_int1_2, single_int2_1, single_int2_2, wrapped_int1_1, wrapped_int1_2, wrapped_int2_1, wrapped_int2_2) -from q2_mystery_stew.generators.collections import list_paramgen, set_paramgen +from q2_mystery_stew.generators.collections import (list_paramgen, + collection_paramgen) OUTPUT_STATES = [EchoOutputBranch1, EchoOutputBranch2, EchoOutputBranch3] @@ -65,7 +66,8 @@ def should_add(filter_): yield from generate_the_matrix( 'typemap_lists', [list_paramgen(x()) for x in selected_types]) yield from generate_the_matrix( - 'typemap_sets', [set_paramgen(x()) for x in selected_types]) + 'typemap_collections', [collection_paramgen(x()) + for x in selected_types]) def _to_action(factory): diff --git a/q2_mystery_stew/plugin_setup.py b/q2_mystery_stew/plugin_setup.py index 95439e0..39d87d3 100644 --- a/q2_mystery_stew/plugin_setup.py +++ b/q2_mystery_stew/plugin_setup.py @@ -23,7 +23,8 @@ from q2_mystery_stew.template import get_disguised_echo_function from q2_mystery_stew.generators import ( get_param_generators, generate_single_type_methods, - generate_multiple_output_methods, generate_typemap_methods, FILTERS) + generate_multiple_output_methods, generate_typemap_methods, + generate_output_collection_methods, FILTERS) from q2_mystery_stew.transformers import ( to_single_int_format, transform_from_metatadata, transform_to_metadata) @@ -61,6 +62,10 @@ def create_plugin(**filters): for action_template in generate_typemap_methods(filters): register_test_method(plugin, action_template) + if not filters or filters.get('output_collections', False): + for action_template in generate_output_collection_methods(): + register_test_method(plugin, action_template) + return plugin @@ -110,7 +115,7 @@ def register_test_method(plugin, action_template): function = get_disguised_echo_function(id=action_template.action_id, python_parameters=python_parameters, - num_outputs=len(qiime_outputs)) + qiime_outputs=qiime_outputs) usage_examples = {} for idx, invocation in enumerate(action_template.invocation_domain): usage_examples[f'example_{idx}'] = UsageInstantiator( diff --git a/q2_mystery_stew/template.py b/q2_mystery_stew/template.py index bc25227..e288155 100644 --- a/q2_mystery_stew/template.py +++ b/q2_mystery_stew/template.py @@ -13,8 +13,12 @@ from q2_mystery_stew.format import SingleIntFormat, EchoOutputFmt +OUTPUT_COLLECTION_SIZE = 2 +OUTPUT_COLLECTION_START = 42 +OUTPUT_COLLECTION_END = OUTPUT_COLLECTION_START + OUTPUT_COLLECTION_SIZE -def get_disguised_echo_function(id, python_parameters, num_outputs): + +def get_disguised_echo_function(id, python_parameters, qiime_outputs): TEMPLATES = [ _function_template_1output, _function_template_2output, @@ -23,8 +27,24 @@ def get_disguised_echo_function(id, python_parameters, num_outputs): _function_template_5output, ] - function = TEMPLATES[num_outputs - 1] - disguise_echo_function(function, id, python_parameters, num_outputs) + # If the first output is a Collection we check how many outputs we have + if str(qiime_outputs[0][1]) == 'Collection[EchoOutput]': + # If we only have the collection we use this template + if len(qiime_outputs) == 1: + function = _function_template_collection_only + # Otherwise, the collection is first of several, so we use this one + else: + function = _function_template_collection_first + # Now we need to check if the second argument is a collection, only the + # first or second ever will be + elif len(qiime_outputs) > 1 \ + and str(qiime_outputs[1][1]) == 'Collection[EchoOutput]': + function = _function_template_collection_second + # In all other cases, we do not need a collection output template + else: + function = TEMPLATES[len(qiime_outputs) - 1] + + disguise_echo_function(function, id, python_parameters, len(qiime_outputs)) return function @@ -55,41 +75,77 @@ def argument_to_line(name, arg): qiime2.NumericMetadataColumn)): value = arg.to_series().to_json() - # We need a list so we can jsonize it (cannot jsonize sets) - sort = False - if type(arg) is list or type(arg) is set: + if type(arg) is list: temp = [] for i in value: - # If we are given a set of artifacts it will be turned into a list - # by the framework, so we need to be ready to accept a list if isinstance(i, SingleIntFormat): temp.append(i.get_int()) expected_type = 'list' - sort = True else: temp.append(i) - # If we turned a set into a list for json purposes, we need to sort it - # to ensure it is always in the same order - if type(arg) is set or sort: - value = sorted(temp, key=repr) else: value = temp + elif type(arg) is qiime2.ResultCollection or type(arg) is dict: + temp = {} + for k, v in value.items(): + if isinstance(v, SingleIntFormat): + temp[k] = v.get_int() + expected_type = 'dict' + else: + temp[k] = v + + value = temp return json.dumps([name, value, expected_type]) + '\n' -def _echo_outputs(kwargs, num_outputs): - output = EchoOutputFmt() - with output.open() as fh: +def _echo_outputs(kwargs, num_outputs, collection_idx=None): + outputs = [] + + if collection_idx == 0: + output = _echo_collection(kwargs=kwargs, idx=0) + else: + output = _echo_single(kwargs=kwargs, idx=0) + + outputs.append(output) + + # We already handled the 1st output above + for idx in range(1, num_outputs): + if idx == collection_idx: + output = _echo_collection(idx=idx) + else: + output = _echo_single(idx=idx) + + outputs.append(output) + + return tuple(outputs) + + +def _echo_collection(kwargs=None, idx=None): + outputs = {} + + if kwargs: for name, arg in kwargs.items(): - fh.write(argument_to_line(name, arg)) + outputs[name] = _echo_single(kwargs={name: arg}) + else: + for i in range(OUTPUT_COLLECTION_START, OUTPUT_COLLECTION_END): + # Elements within a collection have a dual index, the index of the + # entire output followed by the index of the element within the + # collection + outputs[i] = _echo_single(idx=f'{idx}: {i}') + + return outputs - if num_outputs > 1: - output = (output, *map(lambda x: EchoOutputFmt(), - range(1, num_outputs))) - for idx, fmt in enumerate(output[1:], 2): - with fmt.open() as fh: - fh.write(str(idx)) + +def _echo_single(kwargs=None, idx=None): + output = EchoOutputFmt() + + with output.open() as fh: + if kwargs: + for name, arg in kwargs.items(): + fh.write(argument_to_line(name, arg)) + else: + fh.write(str(idx)) return output @@ -112,3 +168,15 @@ def _function_template_4output(**kwargs): def _function_template_5output(**kwargs): return _echo_outputs(kwargs, 5) + + +def _function_template_collection_only(**kwargs): + return _echo_outputs(kwargs, 1, 0) + + +def _function_template_collection_first(**kwargs): + return _echo_outputs(kwargs, 2, 0) + + +def _function_template_collection_second(**kwargs): + return _echo_outputs(kwargs, 2, 1) diff --git a/q2_mystery_stew/usage.py b/q2_mystery_stew/usage.py index e671eee..dbcd51b 100644 --- a/q2_mystery_stew/usage.py +++ b/q2_mystery_stew/usage.py @@ -9,10 +9,13 @@ import re import qiime2 +from qiime2.sdk import ResultCollection, Result from qiime2.sdk.util import (is_semantic_type, is_metadata_type, is_metadata_column_type) +from qiime2.sdk.usage import COLLECTION_VAR_TYPES -from q2_mystery_stew.template import argument_to_line +from q2_mystery_stew.template import ( + argument_to_line, OUTPUT_COLLECTION_START, OUTPUT_COLLECTION_END) class UsageInstantiator: @@ -46,23 +49,53 @@ def do(use_method, *args): inputs[name] = realized_arguments[name] = None elif is_semantic_type(spec.qiime_type): - if type(argument) == list or type(argument) == set: + if type(argument) is list or type(argument) is dict: collection_type = type(argument) realized_arguments[name] = collection_type() inputs[name] = collection_type() - for arg in argument: - artifact = arg() - view = artifact.view(spec.view_type) - view.__hide_from_garbage_collector = artifact - var = do(use.init_artifact, arg.__name__, arg) + if collection_type == list: + for arg in argument: + artifact = arg() + view = artifact.view(spec.view_type) + view.__hide_from_garbage_collector = artifact + var = do(use.init_artifact, arg.__name__, arg) - if collection_type == list: realized_arguments[name].append(view) inputs[name].append(var) - elif collection_type == set: - realized_arguments[name].add(view) - inputs[name].add(var) + + # we know that if we're not a list, we'll be a dict + else: + for key, arg in argument.items(): + artifact = arg() + view = artifact.view(spec.view_type) + view.__hide_from_garbage_collector = artifact + + realized_arguments[name][key] = view + + def _closure(argument): + # We need to bind the argument from the loop above + # for the factory to use the correct one. + # Otherwise the argument will always be the last + # element. + def factory(): + _input = {} + for k, v in argument.items(): + if callable(v): + v = v() + _input[k] = v + if all(isinstance(v, Result) + for v in _input.values()): + _input = ResultCollection(_input) + + return _input + return factory + # neato! + factory = _closure(argument) + + var = do(use.init_result_collection, name, factory) + inputs[name] = var + else: artifact = argument() view = artifact.view(spec.view_type) @@ -148,11 +181,33 @@ def _assert_output(self, computed_results, output_name, expected_type, idx, output = computed_results[idx] output.assert_output_type(semantic_type=expected_type) - if idx == 0: + if output.var_type in COLLECTION_VAR_TYPES: + self._assert_output_collection(output, idx, realized_arguments, + expected_type) + else: + self._assert_output_single(output, idx, realized_arguments) + + def _assert_output_collection(self, output, idx, realized_arguments, + expected_type): + inner_type = expected_type.fields[0] + for i in range(OUTPUT_COLLECTION_START, OUTPUT_COLLECTION_END): + output.assert_output_type(semantic_type=inner_type, key=i) + self._assert_output_single( + output, idx, realized_arguments, key=i, + expression=f"{idx}: {i}") + + def _assert_output_single(self, output, idx, realized_arguments, key=None, + expression=None): + if idx == 0 and realized_arguments: for name, arg in realized_arguments.items(): regex = self._fmt_regex(name, arg) output.assert_has_line_matching(path='echo.txt', - expression=regex) + expression=regex, + key=key) else: + if expression is None: + expression = str(idx) + output.assert_has_line_matching(path='echo.txt', - expression=str(idx + 1)) + expression=expression, + key=key)