Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2483 support derived types in driver #2706

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions lib/extract/netcdf/extract_netcdf_base.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ contains
call this%PSyDataBaseType%PreStart(module_name, region_name, &
num_pre_vars, num_post_vars)
if (this%verbosity >= 1) then
write(stderr,*) "Opening ", trim(module_name//"-"//region_name//".nc")
write(stderr,*) "Opening ", trim(module_name) // "-" // &
trim(region_name) // ".nc"
endif


Expand Down Expand Up @@ -279,8 +280,8 @@ contains

call this%PSyDataBaseType%PreEndDeclaration()
if (this%verbosity >= 1) then
write(stderr,*) "Ending definition ", this%module_name // &
"-"//this%region_name//".nc"
write(stderr,*) "Ending definition ", trim(this%module_name) // &
"-"//trim(this%region_name)//".nc"
endif
retval = CheckError(nf90_enddef(this%ncid))

Expand All @@ -301,8 +302,8 @@ contains

integer :: retval
if (this%verbosity >= 1) then
write(stderr,*) "Closing ", this%module_name//"-" //&
this%region_name//".nc"
write(stderr,*) "Closing ", trim(this%module_name)//"-" //&
trim(this%region_name)//".nc"
endif
retval = CheckError(nf90_close(this%ncid))
call this%PSyDataBaseType%PostEnd()
Expand Down
25 changes: 25 additions & 0 deletions src/psyclone/core/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,31 @@ def __ge__(self, other):
f"'Signature' and '{type(other).__name__}'.")
return self._signature >= other._signature

# ------------------------------------------------------------------------
def create_reference(self, symbol):
'''Creates a reference to the signature using the specified symbol.
The reference can either be a StructureReference, or a normal
Reference (depending if the signature is a structure or not)

:param symbol: the symbol to use when creating the
(Structure)Reference
:type symbol: :py:class:`psyclone.psyir.symbols.Symbol`

:returns: a reference to the signature using the specified symbol.
:rtype: :py:class:`psyclone.psyir.nodes.Reference`

'''

# Avoid circular import (Reference uses VariableAccessInfo)
# pylint: disable=import-outside-toplevel
from psyclone.psyir.nodes import Reference, StructureReference

if self.is_structure:
ref = StructureReference.create(symbol, list(self._signature[1:]))
else:
ref = Reference(symbol)
return ref

# ------------------------------------------------------------------------
@property
def var_name(self):
Expand Down
25 changes: 19 additions & 6 deletions src/psyclone/domain/common/base_driver_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
implementations.
'''

from psyclone.core import Signature
from psyclone.psyir.nodes import Call, Literal, Reference
from psyclone.psyir.symbols import (CHARACTER_TYPE, ContainerSymbol,
ImportInterface, INTEGER_TYPE, NoType,
Expand Down Expand Up @@ -87,7 +88,13 @@ def add_call(program, name, args):
@staticmethod
def add_result_tests(program, output_symbols):
'''Adds tests to check that all output variables have the expected
value.
value. It takes a list of tuples. Each tuple contains:
1. the symbol containing the result when the kernel is called in the
driver.
2. the symbol containing the original results, i.e. the values read
from the extracted file.
3. The signature of the original access. This is used in case of
derived types to get the members used.

:param program: the program to which the tests should be added.
:type program: :py:class:`psyclone.psyir.nodes.Routine`
Expand All @@ -98,7 +105,8 @@ def add_result_tests(program, output_symbols):
values that have been read in from a file.
:type output_symbols: list[tuple[
:py:class:`psyclone.psyir.symbols.Symbol`,
:py:class:`psyclone.psyir.symbols.Symbol`]]
:py:class:`psyclone.psyir.symbols.Symbol`,
:py:class:`psyclone.core.signature.Signature`]]
'''

module = ContainerSymbol("compare_variables_mod")
Expand All @@ -114,10 +122,15 @@ def add_result_tests(program, output_symbols):

# TODO #2083: check if this can be combined with psyad result
# comparison.
for (sym_computed, sym_read) in output_symbols:
lit_name = Literal(sym_computed.name, CHARACTER_TYPE)
for (sym_computed, sym_read, signature) in output_symbols:
# First create the reference (including potential member access)
# to the newly computed value, and the value read from the file:
ref_computed = signature.create_reference(sym_computed)
ref_read = Reference(sym_read)
# Create the tag that was used in extraction:
name_in_kernel_file = Signature(sym_computed.name, signature[1:])
lit_name = Literal(str(name_in_kernel_file), CHARACTER_TYPE)
BaseDriverCreator.add_call(program, "compare",
[lit_name, Reference(sym_computed),
Reference(sym_read)])
[lit_name, ref_computed, ref_read])

BaseDriverCreator.add_call(program, "compare_summary", [])
2 changes: 1 addition & 1 deletion src/psyclone/domain/common/extract_driver_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def create_read_in_code(program, psy_data, read_write_info, postfix):
set_zero = Assignment.create(Reference(sym),
Literal("0", INTEGER_TYPE))
program.addchild(set_zero)
output_symbols.append((sym, post_sym))
output_symbols.append((sym, post_sym, signature))
return output_symbols

# -------------------------------------------------------------------------
Expand Down
87 changes: 56 additions & 31 deletions src/psyclone/domain/lfric/lfric_extract_driver_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def _add_all_kernel_symbols(self, sched, symbol_table, proxy_name_mapping,
new_symbol.replace_symbols_using(symbol_table)
reference.symbol = new_symbol

# Now handle all derived type. The name of a derived type is
# Now handle all derived types. The name of a derived type is
# 'flattened', i.e. all '%' are replaced with '_', and this is then
# declared as a non-structured type. We also need to make sure that a
# flattened name does not clash with a variable declared by the user.
Expand All @@ -381,7 +381,9 @@ def _add_all_kernel_symbols(self, sched, symbol_table, proxy_name_mapping,
proxy_name_mapping)

# Now add all non-local symbols, which need to be
# imported from the appropriate module:
# imported from the appropriate module. Note that
# this will not create the `_post ` version of the
# variables.
# -----------------------------------------------
mod_man = ModuleManager.get()
for module_name, signature in read_write_info.set_of_all_used_vars:
Expand Down Expand Up @@ -420,7 +422,7 @@ def _add_all_kernel_symbols(self, sched, symbol_table, proxy_name_mapping,

# -------------------------------------------------------------------------
@staticmethod
def _create_output_var_code(name, program, is_input, read_var,
def _create_output_var_code(signature, program, is_input, read_var,
postfix, index=None, module_name=None):
# pylint: disable=too-many-arguments
'''
Expand Down Expand Up @@ -469,31 +471,47 @@ def _create_output_var_code(name, program, is_input, read_var,
# variable and the one storing the expected results have the same
# type, look up the 'original' variable and declare the _POST variable
symbol_table = program.symbol_table
tag = signature[0]
if index:
tag = f"{tag}_{index}_data"
if module_name:
sym = symbol_table.lookup_with_tag(f"{name}@{module_name}")
else:
if index is not None:
sym = symbol_table.lookup_with_tag(f"{name}_{index}_data")
else:
# If it is not indexed then `name` will already end in "_data"
sym = symbol_table.lookup_with_tag(name)
tag = f"{tag}@{module_name}"
sym = symbol_table.lookup_with_tag(tag)

# Declare a 'post' variable of the same type and read in its value.
post_name = sym.name + postfix
post_sym = symbol_table.new_symbol(post_name,
symbol_type=DataSymbol,
datatype=sym.datatype)
if module_name:
post_tag = f"{name}{postfix}@{module_name}"
if module_name and hasattr(sym.datatype, "interface"):
flat_name = LFRicExtractDriverCreator._flatten_signature(signature)
post_name = f"{flat_name}_{module_name}{postfix}"
mod_man = ModuleManager.get()
mod_info = mod_man.get_module_info(module_name)
datatype = mod_info.get_symbol(sym.datatype.name)

if isinstance(datatype, DataTypeSymbol):
# This is a structure. We need to create a flattened name
# and fine the base type of the member involved
datatype = datatype.datatype
for member in signature[1:]:
datatype = datatype.components[member].datatype
post_sym = symbol_table.new_symbol(post_name,
symbol_type=DataSymbol,
datatype=datatype)
else:
if index is not None:
post_tag = f"{name}{postfix}%{index}"
else:
# If it is not indexed then `name` will already end in "_data"
post_tag = f"{name}{postfix}"
post_name = sym.name + postfix
post_sym = symbol_table.new_symbol(post_name,
symbol_type=DataSymbol,
datatype=sym.datatype)

# Now create the read call for the _post variable
post_tag = f"{signature}{postfix}"
if index:
post_tag = f"{post_tag}%{index}"
if module_name:
post_tag = f"{post_tag}@{module_name}"

name_lit = Literal(post_tag, CHARACTER_TYPE)
ref = Reference(post_sym)
BaseDriverCreator.add_call(program, read_var,
[name_lit, Reference(post_sym)])
[name_lit, ref])

# Now if a variable is written to, but not read, the variable
# is not allocated. So we need to allocate it and set it to 0.
Expand All @@ -506,10 +524,10 @@ def _create_output_var_code(name, program, is_input, read_var,
IntrinsicCall.Intrinsic.ALLOCATE,
[Reference(sym), ("mold", Reference(post_sym))])
program.addchild(alloc)
set_zero = Assignment.create(Reference(sym),
set_zero = Assignment.create(signature.create_reference(sym),
Literal("0", INTEGER_TYPE))
program.addchild(set_zero)
return (sym, post_sym)
return (sym, post_sym, signature)

# -------------------------------------------------------------------------
def _create_read_in_code(self, program, psy_data, original_symbol_table,
Expand Down Expand Up @@ -615,12 +633,19 @@ def _sym_is_field(sym):
# because we couldn't successfully parse the module)
# and will have inconsistent/missing declarations.
continue
name_lit = Literal(tag, CHARACTER_TYPE)
name_lit = Literal(f"{signature}@{module_name}",
CHARACTER_TYPE)
else:
sym = symbol_table.lookup_with_tag(str(signature))
name_lit = Literal(str(signature), CHARACTER_TYPE)

try:
ref = signature.create_reference(sym)
except TypeError:
ref = Reference(sym)

self.add_call(program, read_var,
[name_lit, Reference(sym)])
[name_lit, ref])

# Then handle all variables that are written (note that some
# variables might be read and written)
Expand All @@ -632,7 +657,6 @@ def _sym_is_field(sym):
# file. The content of these two variables should be identical
# at the end.
output_symbols = []

for module_name, signature in read_write_info.write_list:
# Find the right symbol for the variable. Note that all variables
# in the input and output list have been detected as being used
Expand All @@ -657,15 +681,16 @@ def _sym_is_field(sym):
upper = int(orig_sym.datatype.shape[0].upper.value)
for i in range(1, upper+1):
sym_tuple = \
self._create_output_var_code(flattened, program,
is_input, read_var,
postfix, index=i,
self._create_output_var_code(Signature(flattened),
program, is_input,
read_var, postfix,
index=i,
module_name=module_name)
output_symbols.append(sym_tuple)
else:
sig_str = str(signature)
sym_tuple = \
self._create_output_var_code(str(signature), program,
self._create_output_var_code(signature, program,
is_input, read_var, postfix,
module_name=module_name)
output_symbols.append(sym_tuple)
Expand Down
12 changes: 7 additions & 5 deletions src/psyclone/parse/module_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@


from collections import OrderedDict
import copy
from difflib import SequenceMatcher
import os
import re
Expand Down Expand Up @@ -93,7 +92,7 @@ def __init__(self):
# Setup the regex used to find Fortran modules. Have to be careful not
# to match e.g. "module procedure :: some_sub".
self._module_pattern = re.compile(r"^\s*module\s+([a-z]\S*)\s*$",
flags=(re.IGNORECASE | re.MULTILINE))
flags=re.IGNORECASE | re.MULTILINE)

# ------------------------------------------------------------------------
def add_search_path(self, directories, recursive=True):
Expand Down Expand Up @@ -378,8 +377,11 @@ def sort_modules(self, module_dependencies):
'''
result = []

# Create a copy to avoid modifying the callers data structure:
todo = copy.deepcopy(module_dependencies)
# Create a copy to avoid modifying the callers data structure, and
# also make sure all dependencies are in lower case
todo = {}
for module, dependencies in module_dependencies.items():
todo[module.lower()] = set(i.lower() for i in dependencies)

# Consistency check: test that all dependencies listed are also
# a key in the list, otherwise there will be a dependency that
Expand Down Expand Up @@ -415,7 +417,7 @@ def sort_modules(self, module_dependencies):
# dependencies, the best we can do in this case - and
# it's better to provide all modules (even if they cannot)
# be sorted, than missing some.
all_mods_sorted = sorted((mod for mod in todo.keys()),
all_mods_sorted = sorted((mod for mod in todo),
key=lambda x: len(todo[x]))
mod = all_mods_sorted[0]

Expand Down
32 changes: 29 additions & 3 deletions src/psyclone/tests/core/signature_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@

'''This module tests the Signature class.'''

from __future__ import absolute_import
import pytest

from psyclone.core import ComponentIndices, Signature
from psyclone.errors import InternalError
from psyclone.psyir.backend.c import CWriter
from psyclone.psyir.backend.fortran import FortranWriter
from psyclone.psyir.nodes import Reference
from psyclone.psyir.symbols import DataSymbol, INTEGER_SINGLE_TYPE
from psyclone.psyir.nodes import Reference, StructureReference
from psyclone.psyir.symbols import (DataSymbol, DataTypeSymbol,
INTEGER_SINGLE_TYPE,
StructureType, Symbol)


def test_signature():
Expand Down Expand Up @@ -239,3 +240,28 @@ def test_output_languages():
assert sig.to_language(comp) == "a(1)%b%c(i,j)"
assert sig.to_language(comp, f_writer) == "a(1)%b%c(i,j)"
assert sig.to_language(comp, c_writer) == "a[1].b.c[i + j * cLEN1]"


def test_create_reference(fortran_writer):
'''Tests the create_reference function.
'''

# First define a structure type:
grid_type = StructureType.create([("nx", INTEGER_SINGLE_TYPE,
Symbol.Visibility.PUBLIC, None)])
grid_type_symbol = DataTypeSymbol("grid_type", grid_type)
symbol = DataSymbol("a", grid_type_symbol)

sig = Signature("a")

# Test a non-structure reference:
ref = sig.create_reference(symbol)
# pylint: disable=unidiomatic-typecheck
assert type(ref) is Reference
assert fortran_writer(ref) == "a"

# Test a structure reference:
sig = Signature(["a", "b"])
struct_ref = sig.create_reference(symbol)
assert type(struct_ref) is StructureReference
assert fortran_writer(struct_ref) == "a%b"
4 changes: 3 additions & 1 deletion src/psyclone/tests/domain/common/base_driver_creator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import pytest

from psyclone.core import Signature
from psyclone.domain.common import BaseDriverCreator
from psyclone.psyir.nodes import Literal, Routine
from psyclone.psyir.symbols import DataSymbol, INTEGER_TYPE, RoutineSymbol
Expand Down Expand Up @@ -78,7 +79,8 @@ def test_lfric_driver_add_result_tests(fortran_writer):
"a1_orig", symbol_type=DataSymbol, datatype=INTEGER_TYPE)
# This will add one test for the variable a1 with the
# correct values a1_orig.
BaseDriverCreator.add_result_tests(program, [(a1, a1_orig)])
BaseDriverCreator.add_result_tests(program,
[(a1, a1_orig, Signature("a"))])
out = fortran_writer(program)
expected = """ call compare_init(1)
call compare('a1', a1, a1_orig)
Expand Down
Loading
Loading