Skip to content

Commit

Permalink
Swifty dropbox merge beta (#349)
Browse files Browse the repository at this point in the history
Co-authored-by: Taylor Case <[email protected]>
  • Loading branch information
julianlocke and Taylor Case authored Aug 28, 2024
1 parent 1cdb8c0 commit 40bd164
Show file tree
Hide file tree
Showing 20 changed files with 1,962 additions and 533 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
matrix:
os: [macos-latest, ubuntu-20.04]
python-version: [3.6, 3.8, 3.11, pypy-3.7]
python-version: [3.8, 3.11, pypy-3.8, pypy-3.10]
steps:
- uses: actions/checkout@v2
- name: Setup Python environment
Expand All @@ -35,5 +35,5 @@ jobs:
- name: Run MyPy
if: matrix.python-version != 'pypy-3.7'
run: |
pip install enum34 mypy typed-ast types-six
pip install enum34 mypy types-six
./mypy-run.sh
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
ply>= 3.4
six>= 1.12.0
packaging>=21.0
Jinja2>= 3.0.3
9 changes: 6 additions & 3 deletions stone/backends/python_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,10 +640,12 @@ def _generate_custom_annotation_processors(self, ns, data_type, extra_annotation
dt, _, _ = unwrap(data_type)
if is_struct_type(dt) or is_union_type(dt):
annotation_types_seen = set()
# If data type enumerates subtypes, recurse to subtypes instead which in turn collect parents' custom annotations
# If data type enumerates subtypes, recurse to subtypes instead which in turn collect
# parents' custom annotations
if is_struct_type(dt) and dt.has_enumerated_subtypes():
for subtype in dt.get_enumerated_subtypes():
for annotation_type, recursive_processor in self._generate_custom_annotation_processors(ns, subtype.data_type):
processors = self._generate_custom_annotation_processors(ns, subtype.data_type)
for annotation_type, recursive_processor in processors:
if annotation_type not in annotation_types_seen:
yield (annotation_type, recursive_processor)
annotation_types_seen.add(annotation_type)
Expand All @@ -653,7 +655,8 @@ def _generate_custom_annotation_processors(self, ns, data_type, extra_annotation
yield (annotation.annotation_type,
generate_func_call(
'bb.make_struct_annotation_processor',
args=[class_name_for_annotation_type(annotation.annotation_type, ns),
args=[class_name_for_annotation_type(annotation.annotation_type,
ns),
'processor']
))
annotation_types_seen.add(annotation.annotation_type)
Expand Down
163 changes: 143 additions & 20 deletions stone/backends/swift.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
from contextlib import contextmanager

import os
from stone.backend import CodeBackend
from stone.backends.swift_helpers import (
fmt_class,
fmt_func,
fmt_obj,
fmt_type,
fmt_var,
fmt_objc_type,
mapped_list_info,
)

from stone.ir import (
Boolean,
Bytes,
DataType,
Float32,
Float64,
Int32,
Int64,
List,
Map,
String,
Timestamp,
UInt32,
UInt64,
Void,
is_list_type,
is_map_type,
is_timestamp_type,
is_union_type,
is_user_defined_type,
unwrap_nullable,
is_nullable_type,
)

_serial_type_table = {
Expand All @@ -37,13 +42,29 @@
Int32: 'Int32Serializer',
Int64: 'Int64Serializer',
List: 'ArraySerializer',
Map: 'DictionarySerializer',
String: 'StringSerializer',
Timestamp: 'NSDateSerializer',
UInt32: 'UInt32Serializer',
UInt64: 'UInt64Serializer',
Void: 'VoidSerializer',
}

_nsnumber_type_table = {
Boolean: '.boolValue',
Bytes: '',
Float32: '.floatValue',
Float64: '.doubleValue',
Int32: '.int32Value',
Int64: '.int64Value',
List: '',
String: '',
Timestamp: '',
UInt32: '.uint32Value',
UInt64: '.uint64Value',
Void: '',
Map: '',
}

stone_warning = """\
///
Expand Down Expand Up @@ -98,24 +119,6 @@ def _func_args(self, args_list, newlines=False, force_first=False, not_init=Fals
sep += '\n' + self.make_indent()
return sep.join(out)

@contextmanager
def class_block(self, thing, protocols=None):
protocols = protocols or []
extensions = []

if isinstance(thing, DataType):
name = fmt_class(thing.name)
if thing.parent_type:
extensions.append(fmt_type(thing.parent_type))
else:
name = thing
extensions.extend(protocols)

extend_suffix = ': {}'.format(', '.join(extensions)) if extensions else ''

with self.block('open class {}{}'.format(name, extend_suffix)):
yield

def _struct_init_args(self, data_type, namespace=None): # pylint: disable=unused-argument
args = []
for field in data_type.all_fields:
Expand All @@ -135,6 +138,113 @@ def _struct_init_args(self, data_type, namespace=None): # pylint: disable=unuse
args.append(arg)
return args

def _objc_init_args(self, data_type, include_defaults=True):
args = []
for field in data_type.all_fields:
name = fmt_var(field.name)
value = fmt_objc_type(field.data_type)
data_type, nullable = unwrap_nullable(field.data_type)

if not include_defaults and (field.has_default or nullable):
continue

arg = (name, value)
args.append(arg)
return args

def _objc_no_defualts_func_args(self, data_type, args_data=None):
args = []
for field in data_type.all_fields:
name = fmt_var(field.name)
_, nullable = unwrap_nullable(field.data_type)
if field.has_default or nullable:
continue
arg = (name, name)
args.append(arg)

if args_data is not None:
_, type_data_list = tuple(args_data)
extra_args = [tuple(type_data[:-1]) for type_data in type_data_list]
for name, _, extra_type in extra_args:
if not is_nullable_type(extra_type):
arg = (name, name)
args.append(arg)

return self._func_args(args)

def _objc_init_args_to_swift(self, data_type, args_data=None, include_defaults=True):
args = []
for field in data_type.all_fields:
name = fmt_var(field.name)
field_data_type, nullable = unwrap_nullable(field.data_type)
if not include_defaults and (field.has_default or nullable):
continue
nsnumber_type = _nsnumber_type_table.get(field_data_type.__class__)
value = '{}{}{}'.format(name,
'?' if nullable and nsnumber_type else '',
nsnumber_type)
if is_list_type(field_data_type):
_, prefix, suffix, list_data_type, _ = mapped_list_info(field_data_type)

value = '{}{}'.format(name,
'?' if nullable else '')
list_nsnumber_type = _nsnumber_type_table.get(list_data_type.__class__)

if not is_user_defined_type(list_data_type) and not list_nsnumber_type:
value = name
else:
value = '{}.map {}'.format(value,
prefix)

if is_user_defined_type(list_data_type):
value = '{}{{ $0.{} }}'.format(value,
self._objc_swift_var_name(list_data_type))
else:
value = '{}{{ $0{} }}'.format(value,
list_nsnumber_type)

value = '{}{}'.format(value,
suffix)
elif is_map_type(field_data_type):
if is_user_defined_type(field_data_type.value_data_type):
value = '{}{}.mapValues {{ $0.swift }}'.format(name,
'?' if nullable else '')
elif is_user_defined_type(field_data_type):
value = '{}{}.{}'.format(name,
'?' if nullable else '',
self._objc_swift_var_name(field_data_type))

arg = (name, value)
args.append(arg)

if args_data is not None:
_, type_data_list = tuple(args_data)
extra_args = [tuple(type_data[:-1]) for type_data in type_data_list]
for name, _, _ in extra_args:
args.append((name, name))

return self._func_args(args)

def _objc_swift_var_name(self, data_type):
parent_type = data_type.parent_type
uw_parent_type, _ = unwrap_nullable(parent_type)
sub_count = 1 if parent_type else 0
while is_user_defined_type(uw_parent_type) and parent_type.parent_type:
sub_count += 1
parent_type = parent_type.parent_type
uw_parent_type, _ = unwrap_nullable(parent_type)

if sub_count == 0 or is_union_type(data_type):
return 'swift'
else:
name = 'Swift'
i = 1
while i <= sub_count:
name = '{}{}'.format('sub' if i == sub_count else 'Sub',
name)
i += 1
return name

def _docf(self, tag, val):
if tag == 'route':
if ':' in val:
Expand All @@ -155,6 +265,14 @@ def _docf(self, tag, val):
else:
return val

def _write_output_in_target_folder(self, output, file_name):
full_path = self.target_folder_path
if not os.path.exists(full_path):
os.mkdir(full_path)
full_path = os.path.join(full_path, file_name)
with open(full_path, "w", encoding='utf-8') as fh:
fh.write(output)

def fmt_serial_type(data_type):
data_type, nullable = unwrap_nullable(data_type)

Expand All @@ -167,6 +285,9 @@ def fmt_serial_type(data_type):

if is_list_type(data_type):
result = result + '<{}>'.format(fmt_serial_type(data_type.data_type))
if is_map_type(data_type):
result = result + '<{}, {}>'.format(fmt_serial_type(data_type.key_data_type),
fmt_serial_type(data_type.value_data_type))

return result if not nullable else 'NullableSerializer'

Expand All @@ -183,6 +304,8 @@ def fmt_serial_obj(data_type):

if is_list_type(data_type):
result = result + '({})'.format(fmt_serial_obj(data_type.data_type))
elif is_map_type(data_type):
result = result + '({})'.format(fmt_serial_obj(data_type.value_data_type))
elif is_timestamp_type(data_type):
result = result + '("{}")'.format(data_type.format)
else:
Expand Down
Loading

0 comments on commit 40bd164

Please sign in to comment.