Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Begin implementing default arguments #388

Draft
wants to merge 1 commit into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions opshin/tests/test_keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,34 @@ def validator(a: int, b: int, c: int) -> int:
"""
ret = eval_uplc_value(source_code, x, y, z)
self.assertEqual(ret, (x - z) * y)

@given(x=st.integers(), y=st.integers(), z=st.integers())
def test_default(self, x: int, y: int, z: int):
source_code = f"""
def simple_example(x: int, y: int, z: int={z}) -> int:
return (x-z)*y

def validator(a: int, b: int) -> int:
return simple_example(a, b)
"""
ret = eval_uplc_value(source_code, x, y)
self.assertEqual(ret, (x - z) * y)

def test_default_wrong_type(self):
source_code = f"""
def simple_example(x: int, y: int, z: int="hello") -> int:
return (x-z)*y

def validator(a: int, b: int) -> int:
return simple_example(a, b)
"""
with self.assertRaises(Exception):
ret = eval_uplc_value(source_code, 1, 2)

def test_no_allow_validator_default(self):
source_code = f"""
def validator(a: int, b: int, c:int=1) -> int:
return a*b*c
"""
with self.assertRaises(Exception):
ret = eval_uplc_value(source_code, 1, 2, 2)
41 changes: 33 additions & 8 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@
# from frozendict import frozendict


def pad_left(a: list, b: list, filler=None):
if len(a) > len(b):
return a, [filler for _ in range(len(a) - len(b))] + b
if len(b) > len(a):
return [filler for _ in range(len(b) - len(a))] + a, b
return a, b


INITIAL_SCOPE = {
# class annotations
"bytes": ByteStringType(),
Expand Down Expand Up @@ -244,6 +252,7 @@ class AggressiveTypeInferencer(CompilingNodeTransformer):
def __init__(self, allow_isinstance_anything=False):
self.allow_isinstance_anything = allow_isinstance_anything
self.FUNCTION_ARGUMENT_REGISTRY = {}
self.FUNCTION_DEFAULT_REGISTRY = {}

# A stack of dictionaries for storing scoped knowledge of variable types
self.scopes = [INITIAL_SCOPE]
Expand Down Expand Up @@ -561,12 +570,18 @@ def visit_arg(self, node: arg) -> typedarg:
return ta

def visit_arguments(self, node: arguments) -> typedarguments:
if node.kw_defaults or node.kwarg or node.kwonlyargs or node.defaults:
if node.kw_defaults or node.kwarg or node.kwonlyargs:
raise NotImplementedError(
"Keyword arguments and defaults not supported yet"
)
ta = copy(node)
ta.args = [self.visit(a) for a in node.args]
ta.defaults = [self.visit(d) for d in node.defaults]
# defaults match last k arguments
for i, (a, d) in enumerate(zip(reversed(ta.args), reversed(ta.defaults))):
assert (
a.typ >= d.typ
), f'Default value must be compatible with argument in argument "{a.orig_arg}" (position {len(ta.args)-i}). Expected {a.typ}, got {d.typ}'
return ta

def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef:
Expand Down Expand Up @@ -604,7 +619,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef:
self.exit_scope()
# We need the function type outside for usage
self.set_variable_type(node.name, tfd.typ)
self.FUNCTION_ARGUMENT_REGISTRY[node.name] = node.args.args
self.FUNCTION_ARGUMENT_REGISTRY[node.name] = tfd.args.args
self.FUNCTION_DEFAULT_REGISTRY[node.name] = tfd.args.defaults
return tfd

def visit_Module(self, node: Module) -> TypedModule:
Expand Down Expand Up @@ -776,8 +792,9 @@ def visit_Call(self, node: Call) -> TypedCall:
), "Keyword arguments can only be used with user defined functions"
keywords = copy(node.keywords)
reg_args = self.FUNCTION_ARGUMENT_REGISTRY[node.func.id]
reg_defs = self.FUNCTION_DEFAULT_REGISTRY[node.func.id]
args = []
for i, a in enumerate(reg_args):
for i, (a, d) in enumerate(zip(*pad_left(reg_args, reg_defs))):
if len(node.args) > i:
args.append(self.visit(node.args[i]))
else:
Expand All @@ -786,11 +803,19 @@ def visit_Call(self, node: Call) -> TypedCall:
for idx, keyword in enumerate(keywords)
if keyword.arg == a.orig_arg
]
assert (
len(candidates) == 1
), f"There should be one keyword or positional argument for the arg {a.orig_arg} but found {len(candidates)}"
args.append(self.visit(candidates[0][1].value))
keywords.pop(candidates[0][0])
if candidates:
assert (
len(candidates) == 1
), f'There should be one keyword or positional argument for the argument "{a.orig_arg}" (position {i+1}) but found {len(candidates)}'
candidate = self.visit(candidates[0][1].value)
keywords.pop(candidates[0][0])
elif d is not None:
candidate = copy(d)
else:
raise AssertionError(
f'Could not find argument for argument "{a.orig_arg}" (position {i+1})'
)
args.append(candidate)
assert (
len(keywords) == 0
), f"Could not match the keywords {[keyword.arg for keyword in keywords]} to any argument"
Expand Down