diff --git a/opshin/tests/test_keywords.py b/opshin/tests/test_keywords.py index 91957733..271498d3 100644 --- a/opshin/tests/test_keywords.py +++ b/opshin/tests/test_keywords.py @@ -96,3 +96,34 @@ def validator(a: int, b: int, c: int) -> int: """ ret = eval_uplc_value(source_code, x, y, z) self.assertEqual(ret, (x - z) * y) + + @given(x=st.integers(), y=st.integers(), z=st.integers()) + def test_default(self, x: int, y: int, z: int): + source_code = f""" +def simple_example(x: int, y: int, z: int={z}) -> int: + return (x-z)*y + +def validator(a: int, b: int) -> int: + return simple_example(a, b) +""" + ret = eval_uplc_value(source_code, x, y) + self.assertEqual(ret, (x - z) * y) + + def test_default_wrong_type(self): + source_code = f""" +def simple_example(x: int, y: int, z: int="hello") -> int: + return (x-z)*y + +def validator(a: int, b: int) -> int: + return simple_example(a, b) +""" + with self.assertRaises(Exception): + ret = eval_uplc_value(source_code, 1, 2) + + def test_no_allow_validator_default(self): + source_code = f""" +def validator(a: int, b: int, c:int=1) -> int: + return a*b*c +""" + with self.assertRaises(Exception): + ret = eval_uplc_value(source_code, 1, 2, 2) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 5a10d064..21de5727 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -24,6 +24,14 @@ # from frozendict import frozendict +def pad_left(a: list, b: list, filler=None): + if len(a) > len(b): + return a, [filler for _ in range(len(a) - len(b))] + b + if len(b) > len(a): + return [filler for _ in range(len(b) - len(a))] + a, b + return a, b + + INITIAL_SCOPE = { # class annotations "bytes": ByteStringType(), @@ -244,6 +252,7 @@ class AggressiveTypeInferencer(CompilingNodeTransformer): def __init__(self, allow_isinstance_anything=False): self.allow_isinstance_anything = allow_isinstance_anything self.FUNCTION_ARGUMENT_REGISTRY = {} + self.FUNCTION_DEFAULT_REGISTRY = {} # A stack of dictionaries for storing scoped knowledge of variable types self.scopes = [INITIAL_SCOPE] @@ -561,12 +570,18 @@ def visit_arg(self, node: arg) -> typedarg: return ta def visit_arguments(self, node: arguments) -> typedarguments: - if node.kw_defaults or node.kwarg or node.kwonlyargs or node.defaults: + if node.kw_defaults or node.kwarg or node.kwonlyargs: raise NotImplementedError( "Keyword arguments and defaults not supported yet" ) ta = copy(node) ta.args = [self.visit(a) for a in node.args] + ta.defaults = [self.visit(d) for d in node.defaults] + # defaults match last k arguments + for i, (a, d) in enumerate(zip(reversed(ta.args), reversed(ta.defaults))): + assert ( + a.typ >= d.typ + ), f'Default value must be compatible with argument in argument "{a.orig_arg}" (position {len(ta.args)-i}). Expected {a.typ}, got {d.typ}' return ta def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: @@ -604,7 +619,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: self.exit_scope() # We need the function type outside for usage self.set_variable_type(node.name, tfd.typ) - self.FUNCTION_ARGUMENT_REGISTRY[node.name] = node.args.args + self.FUNCTION_ARGUMENT_REGISTRY[node.name] = tfd.args.args + self.FUNCTION_DEFAULT_REGISTRY[node.name] = tfd.args.defaults return tfd def visit_Module(self, node: Module) -> TypedModule: @@ -776,8 +792,9 @@ def visit_Call(self, node: Call) -> TypedCall: ), "Keyword arguments can only be used with user defined functions" keywords = copy(node.keywords) reg_args = self.FUNCTION_ARGUMENT_REGISTRY[node.func.id] + reg_defs = self.FUNCTION_DEFAULT_REGISTRY[node.func.id] args = [] - for i, a in enumerate(reg_args): + for i, (a, d) in enumerate(zip(*pad_left(reg_args, reg_defs))): if len(node.args) > i: args.append(self.visit(node.args[i])) else: @@ -786,11 +803,19 @@ def visit_Call(self, node: Call) -> TypedCall: for idx, keyword in enumerate(keywords) if keyword.arg == a.orig_arg ] - assert ( - len(candidates) == 1 - ), f"There should be one keyword or positional argument for the arg {a.orig_arg} but found {len(candidates)}" - args.append(self.visit(candidates[0][1].value)) - keywords.pop(candidates[0][0]) + if candidates: + assert ( + len(candidates) == 1 + ), f'There should be one keyword or positional argument for the argument "{a.orig_arg}" (position {i+1}) but found {len(candidates)}' + candidate = self.visit(candidates[0][1].value) + keywords.pop(candidates[0][0]) + elif d is not None: + candidate = copy(d) + else: + raise AssertionError( + f'Could not find argument for argument "{a.orig_arg}" (position {i+1})' + ) + args.append(candidate) assert ( len(keywords) == 0 ), f"Could not match the keywords {[keyword.arg for keyword in keywords]} to any argument"